Instructions to use joshieyu/Omniparser_FT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use joshieyu/Omniparser_FT with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("joshieyu/Omniparser_FT", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import base64 | |
| import io | |
| from typing import Any, Dict, List, Literal, Optional, Tuple, Union | |
| import cv2 | |
| import easyocr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from PIL.Image import Image as ImageType | |
| from supervision.detection.core import Detections | |
| from supervision.draw.color import Color, ColorPalette | |
| from torchvision.ops import box_convert | |
| from torchvision.transforms import ToPILImage | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| from transformers.image_utils import load_image | |
| from ultralytics import YOLO | |
# NOTE: the reader is instantiated at import time so that the files it requires are downloaded
# beforehand, instead of leaving the endpoint stuck listening whilst they are still being downloaded.
easyocr.Reader(lang_list=["en"])
| class EndpointHandler: | |
| def __init__(self, model_dir: str = "/repository") -> None: | |
| self.device = ( | |
| torch.device("cuda") if torch.cuda.is_available() | |
| else (torch.device("mps") if torch.backends.mps.is_available() | |
| else torch.device("cpu")) | |
| ) | |
| # bounding box detection model | |
| self.yolo = YOLO(f"{model_dir}/icon_detect/model.pt") | |
| # captioning model | |
| self.processor = AutoProcessor.from_pretrained( | |
| "microsoft/Florence-2-base", trust_remote_code=True | |
| ) | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| f"{model_dir}/icon_caption", | |
| torch_dtype=torch.float16, | |
| trust_remote_code=True, | |
| ).to(self.device) | |
| # ocr | |
| self.ocr = easyocr.Reader(["en"]) | |
| # box annotator | |
| self.annotator = BoxAnnotator() | |
| def __call__(self, data: Dict[str, Any]) -> Any: | |
| # data should contain the following: | |
| # "inputs": { | |
| # "image": url/base64, | |
| # (optional) "image_size": {"w": int, "h": int}, | |
| # (optional) "bbox_threshold": float, | |
| # (optional) "iou_threshold": float, | |
| # } | |
| data = data.pop("inputs") | |
| # read image from either url or base64 encoding | |
| image = load_image(data["image"]) | |
| ocr_texts, ocr_bboxes = self.check_ocr_bboxes( | |
| image, | |
| out_format="xyxy", | |
| ocr_kwargs={"text_threshold": 0.8}, | |
| ) | |
| annotated_image, filtered_bboxes_out = self.get_som_labeled_img( | |
| image, | |
| image_size=data.get("image_size", None), | |
| ocr_texts=ocr_texts, | |
| ocr_bboxes=ocr_bboxes, | |
| bbox_threshold=data.get("bbox_threshold", 0.05), | |
| iou_threshold=data.get("iou_threshold", None), | |
| ) | |
| return { | |
| "image": annotated_image, | |
| "bboxes": filtered_bboxes_out, | |
| } | |
| def check_ocr_bboxes( | |
| self, | |
| image: ImageType, | |
| out_format: Literal["xywh", "xyxy"] = "xywh", | |
| ocr_kwargs: Optional[Dict[str, Any]] = {}, | |
| ) -> Tuple[List[str], List[List[int]]]: | |
| if image.mode == "RBGA": | |
| image = image.convert("RGB") | |
| result = self.ocr.readtext(np.array(image), **ocr_kwargs) # type: ignore | |
| texts = [str(item[1]) for item in result] | |
| bboxes = [ | |
| self.coordinates_to_bbox(item[0], format=out_format) for item in result | |
| ] | |
| return (texts, bboxes) | |
| def coordinates_to_bbox( | |
| coordinates: np.ndarray, format: Literal["xywh", "xyxy"] = "xywh" | |
| ) -> List[int]: | |
| match format: | |
| case "xywh": | |
| return [ | |
| int(coordinates[0][0]), | |
| int(coordinates[0][1]), | |
| int(coordinates[2][0] - coordinates[0][0]), | |
| int(coordinates[2][1] - coordinates[0][1]), | |
| ] | |
| case "xyxy": | |
| return [ | |
| int(coordinates[0][0]), | |
| int(coordinates[0][1]), | |
| int(coordinates[2][0]), | |
| int(coordinates[2][1]), | |
| ] | |
| def bbox_area(bbox: List[int], w: int, h: int) -> int: | |
| bbox = [bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h] | |
| return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) | |
| def remove_bbox_overlap( | |
| xyxy_bboxes: List[Dict[str, Any]], | |
| ocr_bboxes: Optional[List[Dict[str, Any]]] = None, | |
| iou_threshold: Optional[float] = 0.7, | |
| ) -> List[Dict[str, Any]]: | |
| filtered_bboxes = [] | |
| if ocr_bboxes is not None: | |
| filtered_bboxes.extend(ocr_bboxes) | |
| for i, bbox_outter in enumerate(xyxy_bboxes): | |
| bbox_left = bbox_outter["bbox"] | |
| valid_bbox = True | |
| for j, bbox_inner in enumerate(xyxy_bboxes): | |
| if i == j: | |
| continue | |
| bbox_right = bbox_inner["bbox"] | |
| if ( | |
| intersection_over_union( | |
| bbox_left, | |
| bbox_right, | |
| ) | |
| > iou_threshold # type: ignore | |
| ) and (area(bbox_left) > area(bbox_right)): | |
| valid_bbox = False | |
| break | |
| if valid_bbox is False: | |
| continue | |
| if ocr_bboxes is None: | |
| filtered_bboxes.append(bbox_outter) | |
| continue | |
| box_added = False | |
| ocr_labels = [] | |
| for ocr_bbox in ocr_bboxes: | |
| if not box_added: | |
| bbox_right = ocr_bbox["bbox"] | |
| if overlap(bbox_right, bbox_left): | |
| try: | |
| ocr_labels.append(ocr_bbox["content"]) | |
| filtered_bboxes.remove(ocr_bbox) | |
| except Exception: | |
| continue | |
| elif overlap(bbox_left, bbox_right): | |
| box_added = True | |
| break | |
| if not box_added: | |
| filtered_bboxes.append( | |
| { | |
| "type": "icon", | |
| "bbox": bbox_outter["bbox"], | |
| "interactivity": True, | |
| "content": " ".join(ocr_labels) if ocr_labels else None, | |
| } | |
| ) | |
| return filtered_bboxes | |
| def get_som_labeled_img( | |
| self, | |
| image: ImageType, | |
| image_size: Optional[Dict[Literal["w", "h"], int]] = None, | |
| ocr_texts: Optional[List[str]] = None, | |
| ocr_bboxes: Optional[List[List[int]]] = None, | |
| bbox_threshold: float = 0.01, | |
| iou_threshold: Optional[float] = None, | |
| caption_prompt: Optional[str] = None, | |
| caption_batch_size: int = 64, # ~2GiB of GPU VRAM (can be increased to 128 which is ~4GiB of GPU VRAM) | |
| ) -> Tuple[str, List[Dict[str, Any]]]: | |
| if image.mode == "RBGA": | |
| image = image.convert("RGB") | |
| w, h = image.size | |
| if image_size is None: | |
| imgsz = {"h": h, "w": w} | |
| else: | |
| imgsz = [image_size.get("h", h), image_size.get("w", w)] | |
| out = self.yolo.predict( | |
| image, | |
| imgsz=imgsz, | |
| conf=bbox_threshold, | |
| iou=iou_threshold or 0.7, | |
| verbose=False, | |
| )[0] | |
| if out.boxes is None: | |
| raise RuntimeError( | |
| "YOLO prediction failed to produce the bounding boxes..." | |
| ) | |
| xyxy_bboxes = out.boxes.xyxy | |
| xyxy_bboxes = xyxy_bboxes / torch.Tensor([w, h, w, h]).to(xyxy_bboxes.device) | |
| image_np = np.asarray(image) # type: ignore | |
| if ocr_bboxes: | |
| ocr_bboxes = torch.tensor(ocr_bboxes) / torch.Tensor([w, h, w, h]) # type: ignore | |
| ocr_bboxes = ocr_bboxes.tolist() # type: ignore | |
| ocr_bboxes = [ | |
| { | |
| "type": "text", | |
| "bbox": bbox, | |
| "interactivity": False, | |
| "content": text, | |
| "source": "box_ocr_content_ocr", | |
| } | |
| for bbox, text in zip(ocr_bboxes, ocr_texts) # type: ignore | |
| if self.bbox_area(bbox, w, h) > 0 | |
| ] | |
| xyxy_bboxes = [ | |
| { | |
| "type": "icon", | |
| "bbox": bbox, | |
| "interactivity": True, | |
| "content": None, | |
| "source": "box_yolo_content_yolo", | |
| } | |
| for bbox in xyxy_bboxes.tolist() | |
| if self.bbox_area(bbox, w, h) > 0 | |
| ] | |
| filtered_bboxes = self.remove_bbox_overlap( | |
| xyxy_bboxes=xyxy_bboxes, | |
| ocr_bboxes=ocr_bboxes, # type: ignore | |
| iou_threshold=iou_threshold or 0.7, | |
| ) | |
| filtered_bboxes_out = sorted( | |
| filtered_bboxes, key=lambda x: x["content"] is None | |
| ) | |
| starting_idx = next( | |
| ( | |
| idx | |
| for idx, bbox in enumerate(filtered_bboxes_out) | |
| if bbox["content"] is None | |
| ), | |
| -1, | |
| ) | |
| filtered_bboxes = torch.tensor([box["bbox"] for box in filtered_bboxes_out]) | |
| non_ocr_bboxes = filtered_bboxes[starting_idx:] | |
| bbox_images = [] | |
| for _, coordinates in enumerate(non_ocr_bboxes): | |
| try: | |
| xmin, xmax = ( | |
| int(coordinates[0] * image_np.shape[1]), | |
| int(coordinates[2] * image_np.shape[1]), | |
| ) | |
| ymin, ymax = ( | |
| int(coordinates[1] * image_np.shape[0]), | |
| int(coordinates[3] * image_np.shape[0]), | |
| ) | |
| cropped_image = image_np[ymin:ymax, xmin:xmax, :] | |
| cropped_image = cv2.resize(cropped_image, (64, 64)) | |
| bbox_images.append(ToPILImage()(cropped_image)) | |
| except Exception: | |
| continue | |
| if caption_prompt is None: | |
| caption_prompt = "<CAPTION>" | |
| captions = [] | |
| for idx in range(0, len(bbox_images), caption_batch_size): # type: ignore | |
| batch = bbox_images[idx : idx + caption_batch_size] # type: ignore | |
| inputs = self.processor( | |
| images=batch, | |
| text=[caption_prompt] * len(batch), | |
| return_tensors="pt", | |
| do_resize=False, | |
| ) | |
| if self.device.type in {"cuda", "mps"}: | |
| inputs = inputs.to(device=self.device, dtype=torch.float16) | |
| with torch.inference_mode(): | |
| generated_ids = self.model.generate( | |
| input_ids=inputs["input_ids"], | |
| pixel_values=inputs["pixel_values"], | |
| max_new_tokens=20, | |
| num_beams=1, | |
| do_sample=False, | |
| early_stopping=False, | |
| ) | |
| generated_texts = self.processor.batch_decode( | |
| generated_ids, skip_special_tokens=True | |
| ) | |
| captions.extend([text.strip() for text in generated_texts]) | |
| ocr_texts = [f"Text Box ID {idx}: {text}" for idx, text in enumerate(ocr_texts)] # type: ignore | |
| for _, bbox in enumerate(filtered_bboxes_out): | |
| if bbox["content"] is None: | |
| bbox["content"] = captions.pop(0) | |
| filtered_bboxes = box_convert( | |
| boxes=filtered_bboxes, in_fmt="xyxy", out_fmt="cxcywh" | |
| ) | |
| annotated_image = image_np.copy() | |
| bboxes_annotate = filtered_bboxes * torch.Tensor([w, h, w, h]) | |
| xyxy_annotate = box_convert( | |
| bboxes_annotate, in_fmt="cxcywh", out_fmt="xyxy" | |
| ).numpy() | |
| detections = Detections(xyxy=xyxy_annotate) | |
| labels = [str(idx) for idx in range(bboxes_annotate.shape[0])] | |
| annotated_image = self.annotator.annotate( | |
| scene=annotated_image, | |
| detections=detections, | |
| labels=labels, | |
| image_size=(w, h), | |
| ) | |
| assert w == annotated_image.shape[1] and h == annotated_image.shape[0] | |
| out_image = Image.fromarray(annotated_image) | |
| out_buffer = io.BytesIO() | |
| out_image.save(out_buffer, format="PNG") | |
| encoded_image = base64.b64encode(out_buffer.getvalue()).decode("ascii") | |
| return encoded_image, filtered_bboxes_out | |
def area(bbox: List[int]) -> int:
    """Return width times height of a box given as ``[x_min, y_min, x_max, y_max]``."""
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
def intersection_area(bbox_left: List[int], bbox_right: List[int]) -> int:
    """Return the area shared by two ``[x_min, y_min, x_max, y_max]`` boxes.

    The shared region spans from the larger of the two minimum corners to the
    smaller of the two maximum corners; boxes that do not intersect yield ``0``.
    """
    shared_width = min(bbox_left[2], bbox_right[2]) - max(bbox_left[0], bbox_right[0])
    shared_height = min(bbox_left[3], bbox_right[3]) - max(bbox_left[1], bbox_right[1])
    return max(0, shared_width) * max(0, shared_height)
def intersection_over_union(bbox_left: List[int], bbox_right: List[int]) -> float:
    """Return the largest of the IoU and the two per-box intersection ratios.

    The per-box ratios (intersection divided by each box's own area) are only
    taken into account when both boxes have a positive area.
    """
    shared = intersection_area(bbox_left, bbox_right)
    left_area, right_area = area(bbox_left), area(bbox_right)
    # the small epsilon keeps the division defined for degenerate boxes
    iou = shared / (left_area + right_area - shared + 1e-6)
    if left_area > 0 and right_area > 0:
        return max(iou, shared / left_area, shared / right_area)
    return max(iou, 0, 0)
def overlap(
    bbox_left: List[int], bbox_right: List[int], threshold: float = 0.80
) -> bool:
    """Tell whether ``bbox_left`` lies (almost) entirely inside ``bbox_right``.

    Args:
        bbox_left: Box whose coverage is measured, as ``[x_min, y_min, x_max, y_max]``.
        bbox_right: Box it is compared against, in the same layout.
        threshold: Minimum fraction of ``bbox_left``'s area that must be shared
            with ``bbox_right``.

    Returns:
        ``True`` when the shared fraction exceeds ``threshold``; ``False``
        otherwise, including when ``bbox_left`` has no positive area.
    """
    left_area = area(bbox_left)
    if left_area <= 0:
        # a degenerate box cannot be covered, and dividing by it would fail
        return False
    return intersection_area(bbox_left, bbox_right) / left_area > threshold
class BoxAnnotator:
    """Draws bounding boxes and their text labels on an image with OpenCV.

    Attributes are plain copies of the constructor arguments.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,  # type: ignore
        thickness: int = 3,
        text_color: Color = Color.BLACK,  # type: ignore
        text_scale: float = 0.5,
        text_thickness: int = 2,
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        """Store the drawing options.

        Args:
            color: Single color, or palette indexed per detection.
            thickness: Line thickness of the box outline.
            text_color: Stored as ``self.text_color``.
            text_scale: Font scale of the label text.
            text_thickness: Line thickness of the label text.
            text_padding: Padding around the label text inside its background.
            avoid_overlap: Whether to search for a label position that does not
                cover other detections or leave the image.
        """
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """Draw every detection (and its label) on ``scene`` in place.

        Args:
            scene: Image to draw on; it is modified and returned.
            detections: Detections whose ``xyxy`` boxes are drawn.
            labels: One label per detection. When ``None`` or of a different
                length than ``detections`` the class id is used as label.
            skip_label: Draw only the boxes, no labels.
            image_size: ``(width, height)`` used to keep labels inside the image;
                defaults to the dimensions of ``scene``.

        Returns:
            The annotated ``scene``.
        """
        if image_size is None:
            image_size = (scene.shape[1], scene.shape[0])

        font = cv2.FONT_HERSHEY_SIMPLEX
        use_labels = labels is not None and len(detections) == len(labels)

        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            text = labels[i] if use_labels else f"{class_id}"
            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if self.avoid_overlap:
                placement = self.get_optimal_label_pos(
                    self.text_padding,
                    text_width,
                    text_height,
                    x1,
                    y1,
                    x2,
                    y2,
                    detections,
                    image_size,
                )
            else:
                # fixed position: above the box, starting at its left edge
                placement = self._label_candidates(
                    self.text_padding, text_width, text_height, x1, y1, x2, y2
                )[0]
            (
                text_x,
                text_y,
                text_background_x1,
                text_background_y1,
                text_background_x2,
                text_background_y2,
            ) = placement

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )

            # pick black or white text depending on how bright the background is
            box_color = color.as_rgb()
            luminance = (
                0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            )
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene

    @staticmethod
    def _label_candidates(text_padding, text_width, text_height, x1, y1, x2, y2):
        """List the candidate label placements for a box, in order of preference.

        Each candidate is ``(text_x, text_y, background_x1, background_y1,
        background_x2, background_y2)``.
        """
        box_width = 2 * text_padding + text_width
        box_height = 2 * text_padding + text_height
        return [
            # above the box, starting at its left edge
            (
                x1 + text_padding,
                y1 - text_padding,
                x1,
                y1 - 2 * text_padding - text_height,
                x1 + box_width,
                y1,
            ),
            # left of the box, starting at its top edge
            (
                x1 - text_padding - text_width,
                y1 + text_padding + text_height,
                x1 - 2 * text_padding - text_width,
                y1,
                x1,
                y1 + box_height,
            ),
            # right of the box, starting at its top edge
            (
                x2 + text_padding,
                y1 + text_padding + text_height,
                x2,
                y1,
                x2 + box_width,
                y1 + box_height,
            ),
            # above the box, ending at its right edge
            (
                x2 - text_padding - text_width,
                y1 - text_padding,
                x2 - 2 * text_padding - text_width,
                y1 - 2 * text_padding - text_height,
                x2,
                y1,
            ),
        ]

    @staticmethod
    def get_optimal_label_pos(
        text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
    ):
        """Pick the first label placement that is neither covering a detection nor
        leaving the image.

        Args:
            text_padding: Padding around the text inside the label background.
            text_width: Width of the rendered text.
            text_height: Height of the rendered text.
            x1, y1, x2, y2: Corners of the box being labelled.
            detections: All detections; the label should not cover any of them.
            image_size: ``(width, height)`` of the image, or ``None`` to skip the
                bounds check.

        Returns:
            ``(text_x, text_y, background_x1, background_y1, background_x2,
            background_y2)``. When every candidate is rejected the last candidate
            is returned.
        """

        def is_blocked(bg_x1, bg_y1, bg_x2, bg_y2):
            label_rect = [bg_x1, bg_y1, bg_x2, bg_y2]
            covers_detection = any(
                intersection_over_union(label_rect, box.astype(int)) > 0.3
                for box in detections.xyxy
            )
            if covers_detection or image_size is None:
                return covers_detection
            return (
                bg_x1 < 0
                or bg_x2 > image_size[0]
                or bg_y1 < 0
                or bg_y2 > image_size[1]
            )

        candidates = BoxAnnotator._label_candidates(
            text_padding, text_width, text_height, x1, y1, x2, y2
        )
        for candidate in candidates:
            if not is_blocked(*candidate[2:]):
                return candidate
        return candidates[-1]