laguna-vision / handler.py
aaronkazah's picture
Upload handler.py with huggingface_hub
d13bab6 verified
from __future__ import annotations
import asyncio
import base64
import io
import json
import os
import tempfile
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import requests
from PIL import Image
from lagunavision.visual_pipeline import LagunaVisionImagePipeline, VisualProjectorSpec
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Laguna Vision checkpoints.

    A checkpoint is a directory containing both ``projector.pt`` and
    ``projector_spec.json``. It is located once at construction time. The
    vision pipeline itself is built lazily on the first request.
    """

    def __init__(self, path: str = "") -> None:
        """Resolve the model directory and checkpoint; defer pipeline loading.

        Args:
            path: Model repository directory; empty means the current directory.

        Raises:
            FileNotFoundError: If no complete checkpoint exists under ``path``.
        """
        self.model_dir = Path(path or ".").resolve()
        self.checkpoint_dir = self._resolve_checkpoint_dir()
        # Built on the first request by _answer().
        self.pipeline: LagunaVisionImagePipeline | None = None

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Answer a question about an image.

        ``data["inputs"]`` (or ``data`` itself when there is no ``inputs`` key)
        must be a dict. It carries ``question`` and ``image``, or OpenAI-style
        ``messages``, plus optional ``context`` and ``max_new_tokens``.

        Returns:
            ``{"answer": ..., "checkpoint": <checkpoint dir relative to model dir>}``

        Raises:
            ValueError: If the payload is not a dict, the question or image is
                missing, or ``max_new_tokens`` is not a positive integer.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            raise ValueError("inputs must be an object with image and question fields")
        payload = _normalize_payload(payload)

        question = str(payload.get("question") or "").strip()
        if not question:
            raise ValueError("question is required")
        image_value = payload.get("image")
        if image_value is None:
            raise ValueError(
                "image is required as base64, a data URI, an HTTPS URL, a local path, "
                "or OpenAI-style messages content with image_url"
            )
        context = str(payload.get("context") or "")
        # A missing/0/empty request value falls back to the environment default.
        max_new_tokens = int(payload.get("max_new_tokens") or os.environ.get("LAGUNA_MAX_NEW_TOKENS", "128"))
        if max_new_tokens < 1:
            raise ValueError("max_new_tokens must be a positive integer")

        image = _load_image(image_value)
        # The pipeline takes a file path. A temporary *directory* lets it reopen
        # the file by name on every platform; an open NamedTemporaryFile cannot
        # be reopened by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = Path(tmp_dir) / "image.png"
            image.save(image_path)
            # NOTE(review): asyncio.run() creates a fresh event loop per request
            # and raises RuntimeError if this thread already runs a loop; confirm
            # the serving stack invokes the handler from a worker thread.
            answer = asyncio.run(self._answer(image_path, question, context, max_new_tokens))
        return {"answer": answer, "checkpoint": str(self.checkpoint_dir.relative_to(self.model_dir))}

    async def _answer(self, image: Path, question: str, context: str, max_new_tokens: int) -> str:
        """Load the pipeline on first use, then delegate to ``answer_image``."""
        if self.pipeline is None:
            self.pipeline = await self._load_pipeline()
        return await self.pipeline.answer_image(
            image=image,
            question=question,
            context=context,
            max_new_tokens=max_new_tokens,
        )

    async def _load_pipeline(self) -> "LagunaVisionImagePipeline":
        """Build the pipeline from ``projector_spec.json`` and ``projector.pt``.

        ``LAGUNA_BACKBONE`` and ``LAGUNA_MODEL_ID`` override the spec's
        ``backbone`` and ``model_id``. ``LAGUNA_DEVICE`` and
        ``LAGUNA_VISION_DEVICE`` default to ``"auto"``. A ``lora`` subdirectory
        is passed along when present.

        Raises:
            FileNotFoundError: If ``projector_spec.json`` is missing.
            KeyError: If a required spec field (``input_dim``, ``embedding_dim``,
                ``hidden_dim``, or ``model_id`` without the env override) is absent.
        """
        spec_path = self.checkpoint_dir / "projector_spec.json"
        if not spec_path.exists():
            raise FileNotFoundError(f"missing projector_spec.json in {self.checkpoint_dir}")
        spec_row = json.loads(spec_path.read_text(encoding="utf-8"))
        spec = VisualProjectorSpec(
            input_dim=int(spec_row["input_dim"]),
            embedding_dim=int(spec_row["embedding_dim"]),
            hidden_dim=int(spec_row["hidden_dim"]),
            projector=spec_row.get("projector", "mlp"),
            visual_tokens=int(spec_row.get("visual_tokens", 64)),
            encoder=spec_row.get("encoder", "hf"),
            encoder_id=spec_row.get("encoder_id", ""),
            max_tiles=int(spec_row.get("max_tiles", 4)),
            patch_px=int(spec_row.get("patch_px", 32)),
        )
        lora_dir = self.checkpoint_dir / "lora"
        return await LagunaVisionImagePipeline.from_checkpoint(
            checkpoint=self.checkpoint_dir / "projector.pt",
            spec=spec,
            backbone_name=os.environ.get("LAGUNA_BACKBONE") or spec_row.get("backbone", "laguna"),
            model_id=os.environ.get("LAGUNA_MODEL_ID") or spec_row["model_id"],
            device=os.environ.get("LAGUNA_DEVICE", "auto"),
            vision_device=os.environ.get("LAGUNA_VISION_DEVICE", "auto"),
            lora_dir=lora_dir if lora_dir.exists() else None,
        )

    @staticmethod
    def _is_checkpoint_dir(directory: Path) -> bool:
        """Return True if *directory* holds both projector weights and spec."""
        return (directory / "projector.pt").exists() and (directory / "projector_spec.json").exists()

    def _resolve_checkpoint_dir(self) -> Path:
        """Find the checkpoint directory to serve.

        Lookup order:
        1. ``$LAGUNA_CHECKPOINT_PATH`` (default ``latest``), relative to the model dir;
        2. the model dir itself;
        3. the complete checkpoint whose ``projector.pt`` has the newest mtime,
           anywhere under the model dir.

        Raises:
            FileNotFoundError: If no directory contains both required files.
        """
        requested = os.environ.get("LAGUNA_CHECKPOINT_PATH", "latest").strip("/")
        for candidate in (self.model_dir / requested, self.model_dir):
            if self._is_checkpoint_dir(candidate):
                return candidate
        # Skip directories that have weights but no spec: _load_pipeline cannot
        # use them, so picking one would only defer the failure to the first request.
        complete = [
            weights.parent
            for weights in self.model_dir.glob("**/projector.pt")
            if self._is_checkpoint_dir(weights.parent)
        ]
        if complete:
            # max() keeps the first of equal mtimes, matching a stable descending sort.
            return max(complete, key=lambda directory: (directory / "projector.pt").stat().st_mtime)
        raise FileNotFoundError(
            f"no Laguna Vision checkpoint found under {self.model_dir}; "
            "expected latest/projector.pt and latest/projector_spec.json"
        )
def _load_image(value: Any) -> "Image.Image":
    """Decode an image payload into an RGB PIL image.

    Accepted forms, tried in this order:

    * ``bytes``/``bytearray`` containing encoded image data;
    * a string starting with ``data:image/``, in which everything after the
      first comma is base64-decoded;
    * an ``http``/``https`` URL, downloaded with ``requests.get(timeout=20)``;
    * a path to an existing local file;
    * any other string, which is base64-decoded.

    Raises:
        ValueError: If *value* is neither bytes nor a string, or a data URI has
            no comma-separated payload.
        requests.HTTPError: If the URL download returns an error status.
    """
    # NOTE(review): the URL and local-path branches act on request-supplied
    # strings. If callers are untrusted, this permits requests to arbitrary
    # hosts (including internal ones) and reads of arbitrary local image files;
    # consider an allow-list or disabling these branches.

    def _decode(source: Any) -> "Image.Image":
        # Close the source explicitly rather than relying on Pillow's lazy
        # handle management (Pillow may keep a file open, e.g. multi-frame files).
        with Image.open(source) as img:
            return img.convert("RGB")

    if isinstance(value, (bytes, bytearray)):
        return _decode(io.BytesIO(value))
    if not isinstance(value, str):
        raise ValueError("image must be bytes or a string")
    if value.startswith("data:image/"):
        _, sep, encoded = value.partition(",")
        if not sep:
            raise ValueError("image data URI has no ',' separating the header from the payload")
        return _decode(io.BytesIO(base64.b64decode(encoded)))
    if urlparse(value).scheme in {"http", "https"}:
        response = requests.get(value, timeout=20)
        response.raise_for_status()
        return _decode(io.BytesIO(response.content))
    # os.path.isfile() returns False instead of raising when *value* cannot be a
    # path, e.g. a long bare-base64 payload. For such strings
    # pathlib.Path.exists() raises OSError (ENAMETOOLONG) on older Pythons.
    if os.path.isfile(value):
        return _decode(value)
    return _decode(io.BytesIO(base64.b64decode(value)))
def _normalize_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Fold OpenAI-style ``messages`` into flat ``question``/``image`` fields.

    Payloads without ``messages`` (or with ``messages: null``) are returned
    unchanged, as the same object. Otherwise:

    * Text is collected from every message, whether string ``content`` or
      ``text``/``input_text`` items. It is stripped, joined with newlines, and
      used as ``question`` unless the payload already has a non-empty
      ``question``.
    * The last ``image_url``/``input_image`` item found becomes ``image``.
      With no such item, any existing ``image`` is kept.

    The input dict is not mutated; a shallow copy is returned.
    """
    messages = payload.get("messages")
    if messages is None:
        return payload

    question_parts: list[str] = []
    image_value: Any = payload.get("image")
    for message in messages or []:
        if not isinstance(message, dict):
            continue
        content = message.get("content")
        if isinstance(content, str):
            question_parts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get("type")
            if item_type in {"text", "input_text"} and item.get("text"):
                question_parts.append(str(item["text"]))
            elif item_type in {"image_url", "input_image"}:
                # Chat Completions nests the URL in a dict; other shapes pass a
                # bare string in ``image_url`` or ``url``.
                image_url = item.get("image_url")
                if isinstance(image_url, dict):
                    image_value = image_url.get("url") or image_url.get("image_url") or image_value
                else:
                    image_value = image_url or item.get("url") or image_value

    normalized = dict(payload)
    # An explicit null/empty question must not hide text carried in messages;
    # otherwise the caller would reject the request as having no question.
    if not normalized.get("question"):
        normalized["question"] = "\n".join(part.strip() for part in question_parts if part.strip())
    if image_value is not None:
        normalized["image"] = image_value
    return normalized