| from functools import partial |
| import gradio as gr |
| import torch |
| from transformers import AutoModel, AutoTokenizer |
| import spaces |
| import os |
| import tempfile |
| from PIL import Image, ImageDraw |
| import re |
|
|
| |
# Load the tokenizer and the model once, at import time, so every request
# handled later in this file reuses the same objects.
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
# trust_remote_code=True is passed for both loads.
# NOTE(review): this presumably runs code shipped with the checkpoint repo;
# consider pinning a revision.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
# The model is only used for inference in this file, so keep it in eval mode.
model = model.eval()
# The status icon in this message had been mis-encoded and split the literal
# across two lines; it is restored as a single-line string.
print("✅ Model loaded successfully.")
|
|
| |
def find_result_image(path):
    """Return the first readable result image found in *path*.

    Files whose name contains ``"grounding"`` or ``"result"`` are tried in
    sorted order (``os.listdir`` order is arbitrary, sorting makes the choice
    reproducible). A file that cannot be opened as an image is reported on
    stdout and skipped.

    Args:
        path: Directory to scan (non-recursive).

    Returns:
        A fully loaded, in-memory ``PIL.Image.Image``, or ``None`` when no
        candidate file could be opened.
    """
    for filename in sorted(os.listdir(path)):
        if "grounding" not in filename and "result" not in filename:
            continue
        image_path = os.path.join(path, filename)
        try:
            # Image.open is lazy and keeps the file open. copy() forces the
            # pixel data into memory and the ``with`` closes the file, so the
            # returned image stays usable after the directory is removed.
            with Image.open(image_path) as opened:
                return opened.copy()
        except Exception as e:  # best effort: try the next candidate
            print(f"Error opening result image {filename}: {e}")
    return None
|
|
| |
# Resolution presets selectable in the UI; built once instead of on every call.
_SIZE_CONFIGS = {
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
    "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
}
_DEFAULT_SIZE = "Gundam (Recommended)"

# Prompts keyed by the ASCII part of the task label (the icon prefix is ignored).
_TASK_PROMPTS = {
    "Free OCR": "<image>\nFree OCR.",
    "Convert to Markdown": "<image>\n<|grounding|>Convert the document to markdown.",
    "Parse Figure": "<image>\nParse the figure.",
}
_LOCATE_LABEL = "Locate Object by Reference"

# One <|det|>…<|/det|> span, and one [x1, y1, x2, y2] box inside such a span.
_DET_SPAN_RE = re.compile(r"<\|det\|>(.*?)<\|/det\|>", re.DOTALL)
_BOX_RE = re.compile(r"\[\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\s*\]")
# Box coordinates are divided by this value before scaling to pixels.
_COORD_SCALE = 1000


def _build_prompt(task_type, ref_text):
    """Return the model prompt for *task_type*; raise ``gr.Error`` if Locate lacks text."""
    # Keep only the ASCII label so the lookup does not depend on how the
    # leading icon was encoded.
    label = (task_type or "").encode("ascii", "ignore").decode("ascii").strip()
    if label == _LOCATE_LABEL:
        if not ref_text or ref_text.strip() == "":
            raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
        return f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
    # Unknown task types fall back to plain OCR.
    return _TASK_PROMPTS.get(label, _TASK_PROMPTS["Free OCR"])


def _extract_boxes(text, width, height):
    """Return pixel-space ``(x1, y1, x2, y2)`` tuples for every box tagged in *text*."""
    if not text:
        return []
    boxes = []
    for span in _DET_SPAN_RE.finditer(text):
        # NOTE(review): one tag may carry several boxes ("[[..], [..]]"), so
        # every box inside the span is collected - confirm against model output.
        for found in _BOX_RE.finditer(span.group(1)):
            nx1, ny1, nx2, ny2 = (int(value) for value in found.groups())
            # Sort each axis so an inverted box still forms a valid rectangle.
            left, right = sorted((int(nx1 / _COORD_SCALE * width), int(nx2 / _COORD_SCALE * width)))
            top, bottom = sorted((int(ny1 / _COORD_SCALE * height), int(ny2 / _COORD_SCALE * height)))
            boxes.append((left, top, right, bottom))
    return boxes


@spaces.GPU
def process_ocr_task(image, model_size, ref_text, task_type):
    """Run DeepSeek-OCR on *image* and return the text plus an annotated image.

    Every bounding box found in the model's text output is outlined on a copy
    of the uploaded image, whatever the task. When the text contains no boxes,
    the function falls back to a result image found in the inference output
    directory, if there is one.

    Args:
        image: Uploaded ``PIL.Image.Image``, or ``None`` when nothing was uploaded.
        model_size: Name of a resolution preset; unknown names use
            ``"Gundam (Recommended)"``.
        ref_text: Text to locate; only used by the Locate task.
        task_type: Task label from the UI. Only its ASCII part is compared,
            and unknown labels are treated as Free OCR.

    Returns:
        ``(text_result, result_image)``; ``result_image`` may be ``None``.

    Raises:
        gr.Error: If the Locate task is requested with empty ``ref_text``.
    """
    if image is None:
        return "Please upload an image first.", None

    # Validate the request before moving the model so bad input fails fast.
    prompt = _build_prompt(task_type, ref_text)
    config = _SIZE_CONFIGS.get(model_size, _SIZE_CONFIGS[_DEFAULT_SIZE])

    # NOTE(review): the "π" prefix in some log lines is what is left of a
    # mis-encoded icon; the original glyph cannot be recovered from this file.
    print("π Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("✅ Model is on GPU.")

    with tempfile.TemporaryDirectory() as output_path:
        # infer() is given a file path, so persist the upload first.
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        print(f"π Running inference with prompt: {prompt}")
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\nπ Text Result: {text_result}\n====")

        width, height = image.size
        boxes = _extract_boxes(text_result, width, height)

        if boxes:
            print(f"✅ Found {len(boxes)} bounding box(es). Drawing on the original image.")
            # Draw on a copy so the caller's image is left untouched.
            result_image_pil = image.copy()
            painter = ImageDraw.Draw(result_image_pil)
            for box in boxes:
                painter.rectangle(box, outline="red", width=3)
        else:
            print("⚠️ No bounding box coordinates found in text result. Falling back to search for a result image file.")
            result_image_pil = find_result_image(output_path)
            if result_image_pil is not None:
                # Read the pixels now: the file is deleted with the temp dir.
                result_image_pil.load()

    return text_result, result_image_pil
|
|
|
|
| |
# Build the two-tab Gradio interface: plain OCR, and locating a fixed reference.
# NOTE(review): several icons in this file were mis-encoded. Those identifiable
# from context are restored below. A lone "π" is left untouched because its
# original glyph is unknown and the task_type strings must stay identical to
# the ones process_ocr_task compares against.
with gr.Blocks(title="Text Extraction Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳

        Use the tabs below to switch between Free OCR and Locate modes.
        """
    )

    # Both tabs offer the same resolution presets.
    size_choices = ["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"]

    with gr.Tabs():
        with gr.TabItem("Free OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    free_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    free_model_size = gr.Dropdown(choices=size_choices, value="Base", label="⚙️ Resolution Size")
                    free_btn = gr.Button("Run Free OCR", variant="primary")

                with gr.Column(scale=2):
                    free_output_text = gr.Textbox(label="π Text Result", lines=15, show_copy_button=True)
                    free_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # Pre-bind the task so the button only supplies image and size.
            free_ocr = partial(process_ocr_task, task_type="π Free OCR", ref_text="")
            free_btn.click(
                fn=free_ocr,
                inputs=[free_image, free_model_size],
                outputs=[free_output_text, free_output_image],
            )

        with gr.TabItem("Locate"):
            with gr.Row():
                with gr.Column(scale=1):
                    loc_image = gr.Image(type="pil", label="🖼️ Upload Image", sources=["upload", "clipboard"])
                    loc_model_size = gr.Dropdown(choices=size_choices, value="Base", label="⚙️ Resolution Size")
                    loc_btn = gr.Button("Locate", variant="primary")

                with gr.Column(scale=2):
                    loc_output_text = gr.Textbox(label="π Text Result", lines=15, show_copy_button=True)
                    loc_output_image = gr.Image(label="🖼️ Image Result (if any)", type="pil")

            # This tab always searches for the fixed reference text "pets".
            pets_detection = partial(process_ocr_task, task_type="π Locate Object by Reference", ref_text="pets")
            loc_btn.click(
                fn=pets_detection,
                inputs=[loc_image, loc_model_size],
                outputs=[loc_output_text, loc_output_image],
            )

    # Each example row has exactly one value per wired input (image, size).
    # fn is the pre-bound Free OCR callable, so it matches those two inputs.
    gr.Examples(
        examples=[
            ["doc_markdown.png", "Gundam (Recommended)"],
            ["chart.png", "Gundam (Recommended)"],
            ["teacher.jpg", "Base"],
            ["math_locate.jpg", "Small"],
            ["receipt.jpg", "Base"],
        ],
        inputs=[free_image, free_model_size],
        outputs=[free_output_text, free_output_image],
        fn=free_ocr,
        cache_examples=False,
    )
|
|
| |
if __name__ == "__main__":
    # Create the "examples" folder on first run; leave an existing path alone.
    examples_dir = "examples"
    if not os.path.exists(examples_dir):
        os.makedirs(examples_dir)

    # Enable the request queue (at most 20 waiting), then start the app
    # with share=True.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(share=True)