# LandmarkDiff / landmarkdiff / api_client.py
# dreamlessx's picture
# Update landmarkdiff/api_client.py to v0.3.2
# 21efdcf verified
"""Python client for the LandmarkDiff REST API.

Provides a clean interface for interacting with the FastAPI server,
handling image encoding/decoding, error handling, and session management.

Usage:
    from landmarkdiff.api_client import LandmarkDiffClient

    client = LandmarkDiffClient("http://localhost:8000")

    # Single prediction
    result = client.predict("patient.png", procedure="rhinoplasty", intensity=65)
    result.save("output.png")

    # Face analysis
    analysis = client.analyze("patient.png")
    print(f"Fitzpatrick type: {analysis['fitzpatrick_type']}")

    # Batch processing
    results = client.batch_predict(
        ["patient1.png", "patient2.png"],
        procedure="blepharoplasty",
    )
"""
from __future__ import annotations
import base64
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import cv2
import numpy as np
class LandmarkDiffAPIError(Exception):
    """Root exception type for errors reported by the LandmarkDiff API client.

    The client raises it when the server cannot be reached or when the server
    answers with an HTTP error status.
    """
@dataclass
class PredictionResult:
    """Result from a single prediction.

    Attributes:
        output_image: Predicted image as an array that OpenCV can write or
            display (it is passed to ``cv2.imwrite`` / ``cv2.imshow``).
        procedure: Name of the procedure that was requested.
        intensity: Intensity value that was requested.
        confidence: Confidence attached to the prediction (defaults to 0.0).
        landmarks_before: Optional landmarks prior to the modification.
        landmarks_after: Optional landmarks after the modification.
        metrics: Named numeric metrics attached to the prediction.
        metadata: Free-form extra information. ``batch_predict`` stores the
            keys ``"error"`` and ``"path"`` here for items that failed.
    """

    output_image: np.ndarray
    procedure: str
    intensity: float
    confidence: float = 0.0
    landmarks_before: list[Any] | None = None
    landmarks_after: list[Any] | None = None
    metrics: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def save(self, path: str | Path, fmt: str = ".png") -> None:
        """Save the output image to a file.

        Args:
            path: Destination file. Its extension selects the image format.
            fmt: Extension (with or without the leading dot, e.g. ``".png"``
                or ``"jpg"``) appended when ``path`` has no extension of its
                own. Ignored when ``path`` already has one.

        Raises:
            OSError: If OpenCV reports that the image could not be written.
        """
        target = Path(path)
        # OpenCV picks the encoder from the file extension, so a bare name
        # needs one; fall back to ``fmt`` in that case only.
        if fmt and target.name and not target.suffix:
            extension = fmt if fmt.startswith(".") else f".{fmt}"
            target = target.with_name(target.name + extension)
        # cv2.imwrite signals failure through its return value rather than
        # by raising, so the result must be checked explicitly.
        if not cv2.imwrite(str(target), self.output_image):
            raise OSError(f"Failed to write image to {target}")

    def show(self) -> None:
        """Display the output image (requires GUI).

        Blocks until a key is pressed, then closes all OpenCV windows.
        """
        cv2.imshow("LandmarkDiff Prediction", self.output_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
class LandmarkDiffClient:
    """Client for the LandmarkDiff REST API.

    The underlying ``requests`` session is created lazily on first use and
    released by :meth:`close`. The client can also be used as a context
    manager, which closes the session on exit.

    Args:
        base_url: Server URL (e.g. "http://localhost:8000").
        timeout: Request timeout in seconds.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
        # Trailing slashes are dropped so endpoints can be appended as "/name".
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session: Any = None

    def _get_session(self) -> Any:
        """Lazy-initialize requests session.

        Raises:
            ImportError: If the ``requests`` package is not installed.
        """
        if self._session is None:
            try:
                import requests
            except ImportError as e:
                raise ImportError("requests required. Install with: pip install requests") from e
            self._session = requests.Session()
        return self._session

    def _request(self, method: str, endpoint: str, **kwargs: Any) -> Any:
        """Send a request and return the response once its status is OK.

        Centralizes the error translation shared by every API method so that
        connection failures and HTTP error statuses are reported the same way
        regardless of the endpoint.

        Args:
            method: HTTP method name (e.g. "GET" or "POST").
            endpoint: Path appended to ``base_url`` (e.g. "/health").
            **kwargs: Extra arguments forwarded to ``Session.request``.

        Returns:
            The response object after ``raise_for_status`` succeeded.

        Raises:
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        session = self._get_session()
        import requests

        try:
            resp = session.request(
                method, f"{self.base_url}{endpoint}", timeout=self.timeout, **kwargs
            )
            resp.raise_for_status()
        except requests.ConnectionError:
            raise LandmarkDiffAPIError(
                f"Cannot connect to LandmarkDiff server at {self.base_url}. "
                f"Make sure the server is running (python -m landmarkdiff serve)."
            ) from None
        except requests.HTTPError as e:
            raise LandmarkDiffAPIError(
                f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
            ) from None
        # Any other exception (e.g. a read timeout) propagates unchanged.
        return resp

    def _read_image(self, image_path: str | Path) -> bytes:
        """Read image file as bytes.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
        """
        path = Path(image_path)
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")
        return path.read_bytes()

    def _decode_base64_image(self, b64_string: str) -> np.ndarray:
        """Decode a base64-encoded image to numpy array.

        Raises:
            ValueError: If OpenCV cannot decode the bytes as an image.
        """
        img_bytes = base64.b64decode(b64_string)
        arr = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        # cv2.imdecode returns None instead of raising on undecodable data.
        if img is None:
            raise ValueError("Failed to decode base64 image")
        return img

    # ------------------------------------------------------------------
    # API methods
    # ------------------------------------------------------------------

    def health(self) -> dict[str, Any]:
        """Check server health.

        Returns:
            Dict with status and version info.

        Raises:
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        return self._request("GET", "/health").json()

    def procedures(self) -> list[str]:
        """List available surgical procedures.

        Returns:
            List of procedure names (empty if the response has no
            "procedures" entry).

        Raises:
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        return self._request("GET", "/procedures").json().get("procedures", [])

    def predict(
        self,
        image_path: str | Path,
        procedure: str = "rhinoplasty",
        intensity: float = 65.0,
        seed: int = 42,
    ) -> PredictionResult:
        """Run surgical outcome prediction.

        Args:
            image_path: Path to input face image.
            procedure: Surgical procedure type.
            intensity: Intensity of the modification (0-100).
            seed: Random seed for reproducibility.

        Returns:
            PredictionResult with output image and metadata.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            LandmarkDiffAPIError: If server is unreachable or returns an error.
            KeyError: If the response has no "output_image" entry.
            ValueError: If the returned image cannot be decoded.
        """
        image_bytes = self._read_image(image_path)
        files = {"image": ("image.png", image_bytes, "image/png")}
        # Sent as multipart form fields, hence the string conversion.
        data = {
            "procedure": procedure,
            "intensity": str(intensity),
            "seed": str(seed),
        }
        # The request itself goes through _request so that a connection
        # failure is translated just like an HTTP error status.
        result = self._request("POST", "/predict", files=files, data=data).json()
        # Decode output image
        output_img = self._decode_base64_image(result["output_image"])
        return PredictionResult(
            output_image=output_img,
            procedure=procedure,
            intensity=intensity,
            confidence=result.get("confidence", 0.0),
            metrics=result.get("metrics", {}),
            metadata=result.get("metadata", {}),
        )

    def analyze(self, image_path: str | Path) -> dict[str, Any]:
        """Analyze a face image without generating a prediction.

        Returns face landmarks, Fitzpatrick type, pose estimation, etc.

        Args:
            image_path: Path to input face image.

        Returns:
            Dict with analysis results.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        image_bytes = self._read_image(image_path)
        files = {"image": ("image.png", image_bytes, "image/png")}
        return self._request("POST", "/analyze", files=files).json()

    def batch_predict(
        self,
        image_paths: list[str | Path],
        procedure: str = "rhinoplasty",
        intensity: float = 65.0,
        seed: int = 42,
    ) -> list[PredictionResult]:
        """Run batch prediction on multiple images.

        Images are processed one after another. A failure on one image does
        not abort the batch: that image gets a placeholder result (an
        all-zero 512x512x3 image) whose ``metadata`` holds the error message
        under "error" and the input path under "path".

        Args:
            image_paths: List of image file paths.
            procedure: Procedure to apply to all images.
            intensity: Intensity for all images.
            seed: Base random seed; image ``i`` uses ``seed + i``.

        Returns:
            List of PredictionResult objects, one per input path, in order.
        """
        results: list[PredictionResult] = []
        for i, path in enumerate(image_paths):
            try:
                result = self.predict(
                    path,
                    procedure=procedure,
                    intensity=intensity,
                    seed=seed + i,
                )
                results.append(result)
            except Exception as e:
                # Deliberate best-effort: record the failure and keep going.
                results.append(
                    PredictionResult(
                        output_image=np.zeros((512, 512, 3), dtype=np.uint8),
                        procedure=procedure,
                        intensity=intensity,
                        metadata={"error": str(e), "path": str(path)},
                    )
                )
        return results

    def close(self) -> None:
        """Close the HTTP session.

        Safe to call repeatedly; a later request creates a fresh session.
        """
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self) -> LandmarkDiffClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def __repr__(self) -> str:
        return f"LandmarkDiffClient(base_url='{self.base_url}')"