| from langchain.tools import BaseTool |
| from PIL import Image, ImageDraw |
| import requests |
| from dotenv import load_dotenv |
| import os |
| load_dotenv() |
|
|
|
|
| def object_detection_query(filepath): |
| API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50" |
| headers = {"Authorization": "Bearer " + os.environ['HUGGINGFACEHUB_API_TOKEN']} |
| with open(filepath, "rb") as f: |
| data = f.read() |
| response = requests.post(API_URL, headers=headers, data=data) |
| return response.json() |
|
|
| def bounding_box(filepath): |
| |
| output = object_detection_query(filepath) |
|
|
| |
| image = Image.open(filepath).convert('RGB') |
|
|
| |
| draw = ImageDraw.Draw(image) |
|
|
| |
| for detection in output: |
| label = detection['label'] |
| score = detection['score'] |
| box = detection['box'] |
|
|
| |
| draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline="red", width=2) |
|
|
| |
| text = f"{label} ({score:.2f})" |
| draw.text((box['xmin'], box['ymin']-10), text, fill='red') |
|
|
| return image |
|
|
| def captioning_query(filepath): |
| API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large" |
| headers = {"Authorization": "Bearer " + os.environ['HUGGINGFACEHUB_API_TOKEN']} |
| with open(filepath, "rb") as f: |
| data = f.read() |
| response = requests.post(API_URL, headers=headers, data=data) |
| return response.json() |
|
|
| class ImageCaptionTools(BaseTool): |
| name = "Image_Caption_Tools" |
| description = "Use this tool with any given image path to receive a personalized description, poem, story, or more. "\ |
| "Ideal for agents seeking tailored insights. "\ |
| "Let the tool craft content based on your image for a unique perspective." |
|
|
| def _run(self, image_path) -> str: |
| """Use the tool.""" |
| result = captioning_query(image_path) |
| text = result[0]['generated_text'] |
| return text |
|
|
| async def _arun(self, query: str) -> str: |
| """Use the tool asynchronously.""" |
| raise NotImplementedError("custom_search does not support async") |
| |
|
|
| class ObjectDetectionTool(BaseTool): |
| name = "Object_Detection_Tool" |
| description = "Object Detection Tool: Use this tool to detect objects in an image. Provide the image path, " \ |
| "and it will return a list of detected objects. Each element in the list is in the format: " \ |
| "[x1, y1, x2, y2] class_name confidence_score. This tool focuses on object detection, providing " \ |
| "locations of objects in the image. For image descriptions or other insights, explore additional tools." |
|
|
| def _run(self, image_path) -> str: |
| """Use the tool.""" |
| results = object_detection_query(image_path) |
| detections = "" |
| for result in results: |
| box = result['box'] |
| detections += '[{}, {}, {}, {}]'.format(int(box['xmin']), int(box['ymin']), int(box['xmax']), int(box['ymax'])) |
| detections += ' {}'.format(result['label']) |
| detections += ' {}\n'.format(result['score']) |
| return detections |
|
|
| async def _arun(self, query: str) -> str: |
| """Use the tool asynchronously.""" |
| raise NotImplementedError("custom_search does not support async") |
|
|
|
|