| import os |
| import json |
| import time |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| |
| from PIL import Image |
| from qwen_vl_utils import process_vision_info |
| from transformers import ( |
| AutoProcessor, |
| Gemma3ForConditionalGeneration, |
| Qwen2_5_VLForConditionalGeneration, |
| ) |
|
|
| from spaces import GPU |
| import supervision as sv |
|
|
| |
| |
| |
| from huggingface_hub import login |
# Authenticate with the Hugging Face Hub using the token stored in the
# HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # Only log in when a token is actually configured: login(token=None)
    # falls back to an interactive prompt, which cannot be answered when the
    # app runs unattended.
    login(token=hf_token)
|
|
|
|
|
|
# Hub identifiers of the two vision-language checkpoints being compared.
model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
model_gemma_id = "google/gemma-3-4b-it"

# Pixel bounds forwarded to the Qwen processor.
min_pixels = 224 * 224
max_pixels = 1024 * 1024

# Qwen2.5-VL: weights plus the processor configured with the pixel bounds.
model_qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_qwen_id,
    torch_dtype="auto",
    device_map="auto",
)
processor_qwen = AutoProcessor.from_pretrained(
    model_qwen_id,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Gemma 3: weights loaded in bfloat16 plus the default processor.
model_gemma = Gemma3ForConditionalGeneration.from_pretrained(
    model_gemma_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor_gemma = AutoProcessor.from_pretrained(model_gemma_id)
|
|
|
|
def extract_model_short_name(model_id):
    """Return a display name derived from a model identifier.

    Only the text after the last "/" is kept, and every "-" and "_" in it is
    replaced by a space.
    """
    _, _, repo_name = model_id.rpartition("/")
    return repo_name.translate(str.maketrans("-_", "  "))
|
|
|
|
# Display names used in the UI labels and in the timing messages.
model_qwen_name, model_gemma_name = map(
    extract_model_short_name, (model_qwen_id, model_gemma_id)
)
|
|
|
|
def _parse_vlm_json(raw_text):
    """Return the JSON value embedded in a model reply, or None if unusable.

    When the reply contains a ```json fence, only the text between that fence
    and the next ``` is parsed; otherwise the whole reply is parsed. A reply
    that is not a string, is not valid JSON, or is the JSON literal ``null``
    yields None.
    """
    if not isinstance(raw_text, str):
        return None
    _, fence, after_fence = raw_text.partition("```json")
    payload = after_fence.partition("```")[0] if fence else raw_text
    try:
        return json.loads(payload)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input.
        return None


def _scaled_points(items, x_scale, y_scale):
    """Return the "point_2d" entries of *items* scaled to image pixels."""
    points = []
    for item in items:
        if "point_2d" not in item:
            continue
        try:
            x, y = item["point_2d"]
            points.append([int(x * x_scale), int(y * y_scale)])
        except (TypeError, ValueError, OverflowError):
            # Malformed point (wrong arity or non-numeric values): skip it
            # instead of aborting the whole annotation.
            continue
    return points


def create_annotated_image(image, json_data, height, width):
    """Draw the boxes and points described by a model reply onto an image.

    Args:
        image: PIL image that was shown to the model.
        json_data: Raw text reply of the model, optionally wrapping its JSON
            in a ```json fence.
        height: Height of the coordinate space used by the model.
        width: Width of the coordinate space used by the model.

    Returns:
        A new RGB PIL image with annotations drawn on it. If the reply holds
        no usable JSON objects, ``image`` itself is returned unchanged.
    """
    bbox_data = _parse_vlm_json(json_data)
    if bbox_data is None:
        return image

    # A single JSON object is treated like a one-element list.
    if not isinstance(bbox_data, list):
        bbox_data = [bbox_data]

    # Replies such as a bare number parse as JSON but carry nothing that can
    # be drawn; keeping only dict entries avoids a TypeError on `in` below.
    items = [item for item in bbox_data if isinstance(item, dict)]
    if not items:
        return image

    # Factors mapping model-space coordinates onto the original image.
    original_width, original_height = image.size
    x_scale = original_width / width
    y_scale = original_height / height

    annotated_image = np.array(image.convert("RGB"))

    if any("box_2d" in item for item in items):
        # NOTE(review): the raw reply and the model-space resolution are
        # forwarded unchanged; confirm that the installed supervision release
        # accepts this argument set and the "box_2d" key for QWEN_2_5_VL.
        detections = sv.Detections.from_vlm(
            vlm=sv.VLM.QWEN_2_5_VL,
            result=json_data,
            resolution_wh=(width, height),
        )
        bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

        annotated_image = bounding_box_annotator.annotate(
            scene=annotated_image, detections=detections
        )
        annotated_image = label_annotator.annotate(
            scene=annotated_image, detections=detections
        )

    points = _scaled_points(items, x_scale, y_scale)
    if points:
        # All points are passed as a single group of shape (1, n_points, 2).
        points_array = np.array(points).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=5, color=sv.Color.BLUE)
        annotated_image = vertex_annotator.annotate(
            scene=annotated_image, key_points=key_points
        )

    return Image.fromarray(annotated_image)
|
|
|
|
@GPU
def detect_qwen(image, prompt):
    """Run the Qwen model on one image/prompt pair.

    Args:
        image: PIL image to analyse.
        prompt: Text instruction sent together with the image.

    Returns:
        A tuple ``(annotated_image, output_text, time_taken)``: the image
        with the reply drawn on it, the raw reply text, and a markdown string
        with the elapsed time in milliseconds. The measured time covers
        preprocessing, generation and decoding.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    t0 = time.perf_counter()
    text = processor_qwen.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor_qwen(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model_qwen.device)

    generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
    # Drop the prompt tokens so that only the newly generated reply remains.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # do_sample is a generate() option and has no meaning for decoding, so it
    # is not passed to batch_decode.
    output_text = processor_qwen.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    elapsed_ms = (time.perf_counter() - t0) * 1_000

    # Size of the image as seen by the model, derived from the patch grid.
    # Converted to plain ints so downstream code does not receive 0-dim
    # tensors. The factor 14 is presumably the patch size in pixels - TODO
    # confirm against the processor configuration.
    grid_thw = inputs["image_grid_thw"][0]
    input_height = int(grid_thw[1]) * 14
    input_width = int(grid_thw[2]) * 14

    annotated_image = create_annotated_image(
        image, output_text, input_height, input_width
    )

    time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
    return annotated_image, output_text, time_taken
|
|
|
|
@GPU
def detect_gemma(image, prompt):
    """Run the Gemma model on one image/prompt pair.

    Returns a tuple of the annotated image, the raw reply text and a markdown
    string reporting the elapsed time in milliseconds (preprocessing,
    generation and decoding).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }

    started = time.perf_counter()
    model_inputs = processor_gemma.apply_chat_template(
        [user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model_gemma.device)

    # Number of prompt tokens, used to strip the prompt from the output.
    prompt_length = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        sequences = model_gemma.generate(
            **model_inputs, max_new_tokens=1024, do_sample=False
        )
        reply_ids = sequences[0][prompt_length:]

    reply_text = processor_gemma.decode(reply_ids, skip_special_tokens=True)
    elapsed_ms = (time.perf_counter() - started) * 1_000

    # Coordinate space assumed for Gemma replies: a fixed 896 x 896 square.
    model_side = 896
    annotated = create_annotated_image(image, reply_text, model_side, model_side)

    timing_note = f"**Inference time ({model_gemma_name}):** {elapsed_ms:.0f} ms"
    return annotated, reply_text, timing_note
|
|
|
|
def detect(image, prompt_model_1, prompt_model_2):
    """Run both models on the same image and collect their results.

    Args:
        image: PIL image supplied by the UI, or None when nothing was
            uploaded.
        prompt_model_1: Prompt passed to ``detect_qwen``.
        prompt_model_2: Prompt passed to ``detect_gemma``.

    Returns:
        A 6-tuple: annotated image, reply text and timing string of the first
        model, followed by the same three values for the second model.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the call below fails with an AttributeError and
        # the user only sees a generic error.
        raise gr.Error("Please upload an image before clicking Generate.")

    STANDARD_SIZE = (1024, 1024)
    # thumbnail() resizes in place; work on a copy so the caller's image
    # object is left untouched.
    image = image.copy()
    image.thumbnail(STANDARD_SIZE)

    annotated_image_model_1, output_text_model_1, timing_1 = detect_qwen(
        image, prompt_model_1
    )
    annotated_image_model_2, output_text_model_2, timing_2 = detect_gemma(
        image, prompt_model_2
    )

    return (
        annotated_image_model_1,
        output_text_model_1,
        timing_1,
        annotated_image_model_2,
        output_text_model_2,
        timing_2,
    )
|
|
|
|
# Stylesheet handed to gr.Blocks; it targets the share-link button by id.
css_hide_share = """
button#gradio-share-link-button-0 {
display: none !important;
}
"""

# NOTE(review): the nesting below is reconstructed from the `with` statements
# of a source whose indentation was lost - confirm that the examples and the
# click wiring belong outside the Row.
with gr.Blocks(theme=gr.themes.Soft(), css=css_hide_share) as demo:
    # Page header.
    gr.Markdown("# Object Detection & Understanding: Qwen vs. Gemma")
    gr.Markdown(
        "### Compare object detection, visual grounding, and keypoint detection using natural language prompts with two leading VLMs."
    )
    gr.Markdown("""
*Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Gemma 3 4B IT](https://huggingface.co/google/gemma-3-4b-it). For best results, ask the model to return a JSON list in a markdown block. Inspired by the [HF Team's space](https://huggingface.co/spaces/sergiopaniego/vlm_object_understanding), selecting `detect` for categories with "Object Detection" `point` for the ones with "Keypoint Detection", and reasoning-based querying for all others.*
""")

    # The same hint is shown in both prompt boxes.
    prompt_placeholder = "e.g., Detect all red cars. Return a JSON list with 'box_2d' and 'label'."

    with gr.Row():
        # Left column: image upload, one prompt per model, trigger button.
        with gr.Column(scale=2):
            image_input = gr.Image(label="Upload an image", type="pil", height=400)
            prompt_input_model_1 = gr.Textbox(
                label=f"Enter your prompt for {model_qwen_name}",
                placeholder=prompt_placeholder,
            )
            prompt_input_model_2 = gr.Textbox(
                label=f"Enter your prompt for {model_gemma_name}",
                placeholder=prompt_placeholder,
            )
            generate_btn = gr.Button(value="Generate")

        # Middle column: results of the Qwen model.
        with gr.Column(scale=1):
            output_image_model_1 = gr.Image(
                type="pil",
                label=f"Annotated image from {model_qwen_name}",
                height=400,
            )
            output_textbox_model_1 = gr.Textbox(
                label=f"Model response from {model_qwen_name}",
                lines=10,
            )
            output_time_model_1 = gr.Markdown()

        # Right column: results of the Gemma model.
        with gr.Column(scale=1):
            output_image_model_2 = gr.Image(
                type="pil",
                label=f"Annotated image from {model_gemma_name}",
                height=400,
            )
            output_textbox_model_2 = gr.Textbox(
                label=f"Model response from {model_gemma_name}",
                lines=10,
            )
            output_time_model_2 = gr.Markdown()

    gr.Markdown("### Examples")

    # Example prompts: detection, counting, keypoints and grounding.
    prompt_obj_detect = "Detect all objects in this image. For each object, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
    prompt_candy_detect = "Detect all individual candies in this image. For each, provide a 'box_2d' and a 'label'. Return the output as a JSON list inside a markdown block."
    prompt_car_count = "Count the number of red cars in the image."
    prompt_candy_count = "Count the number of blue candies in the image."
    prompt_car_keypoint = "Identify the red cars in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
    prompt_candy_keypoint = "Identify the blue candies in this image. For each, detect its key points and return their positions as 'point_2d' in a JSON list inside a markdown block."
    prompt_car_ground = "Detect the red car that is leading in this image. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."
    prompt_candy_ground = "Detect the blue candy at the top of the group. Return its location with 'box_2d' and 'label' in a JSON list inside a markdown block."

    # (image path, prompt) pairs; each row feeds the same prompt to both models.
    example_cases = [
        ("examples/example_1.jpg", prompt_obj_detect),
        ("examples/example_2.JPG", prompt_candy_detect),
        ("examples/example_1.jpg", prompt_car_count),
        ("examples/example_2.JPG", prompt_candy_count),
        ("examples/example_1.jpg", prompt_car_keypoint),
        ("examples/example_2.JPG", prompt_candy_keypoint),
        ("examples/example_1.jpg", prompt_car_ground),
        ("examples/example_2.JPG", prompt_candy_ground),
    ]
    example_prompts = [
        [image_path, prompt, prompt] for image_path, prompt in example_cases
    ]

    gr.Examples(
        examples=example_prompts,
        inputs=[image_input, prompt_input_model_1, prompt_input_model_2],
        label="Click an example to populate the input",
    )

    # Clicking the button runs both models and fills all six output widgets.
    generate_btn.click(
        fn=detect,
        inputs=[image_input, prompt_input_model_1, prompt_input_model_2],
        outputs=[
            output_image_model_1,
            output_textbox_model_1,
            output_time_model_1,
            output_image_model_2,
            output_textbox_model_2,
            output_time_model_2,
        ],
    )
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()