| import torch |
| import torchvision |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import gradio as gr |
|
|
| |
# Pretrained Mask R-CNN (ResNet-50 FPN) from torchvision, built once at import
# time and switched to evaluation mode. nn.Module.eval() returns the module
# itself, so the two steps can be chained.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is so older torchvision versions keep working.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()
|
|
| |
# Category names indexed by the integer label id the detector returns
# (segment_objects looks names up as COCO_INSTANCE_CATEGORY_NAMES[label_id]).
# Index 0 is the background class. The "N/A" entries presumably pad COCO
# category ids that have no class, keeping indices aligned — TODO confirm.
COCO_INSTANCE_CATEGORY_NAMES = [
    "__background__", "person", "bicycle", "car", "motorcycle",           # 0-4
    "airplane", "bus", "train", "truck", "boat",                          # 5-9
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter", # 10-14
    "bench", "bird", "cat", "dog", "horse",                               # 15-19
    "sheep", "cow", "elephant", "bear", "zebra",                          # 20-24
    "giraffe", "N/A", "backpack", "umbrella", "N/A",                      # 25-29
    "N/A", "handbag", "tie", "suitcase", "frisbee",                       # 30-34
    "skis", "snowboard", "sports ball", "kite", "baseball bat",           # 35-39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 40-44
    "N/A", "wine glass", "cup", "fork", "knife",                          # 45-49
    "spoon", "bowl", "banana", "apple", "sandwich",                       # 50-54
    "orange", "broccoli", "carrot", "hot dog", "pizza",                   # 55-59
    "donut", "cake", "chair", "couch", "potted plant",                    # 60-64
    "bed", "N/A", "dining table", "N/A", "N/A",                           # 65-69
    "toilet", "N/A", "tv", "laptop", "mouse",                             # 70-74
    "remote", "keyboard", "cell phone", "microwave", "oven",              # 75-79
    "toaster", "sink", "refrigerator", "N/A", "book",                     # 80-84
    "clock", "vase", "scissors", "teddy bear", "hair drier",              # 85-89
    "toothbrush",                                                         # 90
]
|
|
| |
def _blend_mask(image_np, mask, color, alpha=0.5):
    """Tint the pixels of ``image_np`` selected by ``mask`` with ``color``, in place.

    Only the masked pixels are touched, so no full-frame temporaries are
    allocated per instance.

    Args:
        image_np: ``(H, W, 3)`` image array; modified in place.
        mask: ``(H, W)`` boolean array selecting the pixels to tint.
        color: Three RGB floats in ``[0, 1]``; each is scaled to ``0..255`` and
            truncated to an integer.
        alpha: Opacity of the tint (``0.5`` gives an even image/colour blend).
    """
    rgb = (np.asarray(color) * 255).astype(int)
    blended = (1 - alpha) * image_np[mask] + alpha * rgb
    image_np[mask] = blended.astype(np.uint8)


def segment_objects(image, threshold=0.5, *, mask_threshold=0.5,
                    output_path="output_maskrcnn_with_masks.png"):
    """Run the Mask R-CNN ``model`` on ``image`` and save an annotated rendering.

    Every detection scoring at least ``threshold`` is drawn three ways:
    - a semi-transparent, randomly coloured mask overlay;
    - a bounding box in the same colour;
    - a text tag with its category name and score.

    Args:
        image: Input image. The Gradio UI supplies a ``PIL.Image``, which is
            converted to RGB so grayscale/RGBA/palette uploads also work.
            ``None`` (nothing uploaded) returns ``None`` without running the model.
        threshold: Minimum detection score for an instance to be drawn.
        mask_threshold: Cut-off applied to the model's per-pixel mask values
            to obtain a binary mask. Keyword-only, so Gradio does not treat it
            as an input.
        output_path: File the rendered PNG is written to. Keyword-only.

    Returns:
        ``output_path``, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        # The detector and the RGB blending below both need exactly 3 channels.
        image = image.convert("RGB")

    img_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        output = model(img_tensor)[0]

    # Filter once up front instead of testing the score inside the loop.
    keep = output['scores'] >= threshold
    masks = output['masks'][keep].cpu().numpy()    # indexed as [instance, 0, H, W]
    boxes = output['boxes'][keep].cpu().numpy()    # rows of (x1, y1, x2, y2)
    labels = output['labels'][keep].cpu().tolist()
    scores = output['scores'][keep].cpu().tolist()

    image_np = np.array(image)  # np.array already copies; safe to modify in place
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        for mask_prob, (x1, y1, x2, y2), label_id, score in zip(
                masks, boxes, labels, scores):
            color = np.random.rand(3)
            _blend_mask(image_np, mask_prob[0] > mask_threshold, color)

            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))
            label = COCO_INSTANCE_CATEGORY_NAMES[label_id]
            ax.text(x1, y1, f"{label}: {score:.2f}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        # Show the fully blended image once; boxes/text draw above it by z-order.
        ax.imshow(image_np)
        ax.axis('off')
        # NOTE(review): the default output_path is shared by every call, so
        # concurrent requests would overwrite each other's file.
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if drawing or saving fails; pyplot would
        # otherwise keep it alive.
        plt.close(fig)
    return output_path
|
|
| |
# Gradio front end. Inputs: an uploaded image and a confidence slider.
# Output: the annotated PNG that segment_objects writes, returned as a file path.
_image_input = gr.Image(type="pil", label="Upload Image")
_threshold_input = gr.Slider(
    minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"
)
_segmented_output = gr.Image(type="filepath", label="Segmented Output")

interface = gr.Interface(
    fn=segment_objects,
    inputs=[_image_input, _threshold_input],
    outputs=_segmented_output,
    title="Mask R-CNN Instance Segmentation",
    description=(
        "Upload an image to detect and segment objects using a pretrained "
        "Mask R-CNN model (TorchVision)."
    ),
)


if __name__ == "__main__":
    # Launch only when run as a script, not on import. debug=True makes Gradio
    # block the main thread while serving and print errors to the console.
    interface.launch(debug=True)
|
|