"""Python client for the LandmarkDiff REST API. Provides a clean interface for interacting with the FastAPI server, handling image encoding/decoding, error handling, and session management. Usage: from landmarkdiff.api_client import LandmarkDiffClient client = LandmarkDiffClient("http://localhost:8000") # Single prediction result = client.predict("patient.png", procedure="rhinoplasty", intensity=65) result.save("output.png") # Face analysis analysis = client.analyze("patient.png") print(f"Fitzpatrick type: {analysis['fitzpatrick_type']}") # Batch processing results = client.batch_predict( ["patient1.png", "patient2.png"], procedure="blepharoplasty", ) """ from __future__ import annotations import base64 from dataclasses import dataclass, field from pathlib import Path from typing import Any import cv2 import numpy as np class LandmarkDiffAPIError(Exception): """Base exception for LandmarkDiff API errors.""" pass @dataclass class PredictionResult: """Result from a single prediction.""" output_image: np.ndarray procedure: str intensity: float confidence: float = 0.0 landmarks_before: list[Any] | None = None landmarks_after: list[Any] | None = None metrics: dict[str, float] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) def save(self, path: str | Path, fmt: str = ".png") -> None: """Save the output image to a file.""" cv2.imwrite(str(path), self.output_image) def show(self) -> None: """Display the output image (requires GUI).""" cv2.imshow("LandmarkDiff Prediction", self.output_image) cv2.waitKey(0) cv2.destroyAllWindows() class LandmarkDiffClient: """Client for the LandmarkDiff REST API. Args: base_url: Server URL (e.g. "http://localhost:8000"). timeout: Request timeout in seconds. """ def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None: self.base_url = base_url.rstrip("/") self.timeout = timeout self._session = None def _get_session(self) -> Any: """Lazy-initialize requests session.""" if self._session is None: try: import requests except ImportError as e: raise ImportError("requests required. Install with: pip install requests") from e self._session = requests.Session() return self._session def _read_image(self, image_path: str | Path) -> bytes: """Read image file as bytes.""" path = Path(image_path) if not path.exists(): raise FileNotFoundError(f"Image not found: {path}") return path.read_bytes() def _decode_base64_image(self, b64_string: str) -> np.ndarray: """Decode a base64-encoded image to numpy array.""" img_bytes = base64.b64decode(b64_string) arr = np.frombuffer(img_bytes, np.uint8) img = cv2.imdecode(arr, cv2.IMREAD_COLOR) if img is None: raise ValueError("Failed to decode base64 image") return img # ------------------------------------------------------------------ # API methods # ------------------------------------------------------------------ def health(self) -> dict[str, Any]: """Check server health. Returns: Dict with status and version info. Raises: LandmarkDiffAPIError: If server is unreachable or returns an error. """ session = self._get_session() try: resp = session.get(f"{self.base_url}/health", timeout=self.timeout) resp.raise_for_status() return resp.json() except Exception as e: import requests if isinstance(e, requests.ConnectionError): raise LandmarkDiffAPIError( f"Cannot connect to LandmarkDiff server at {self.base_url}. " f"Make sure the server is running (python -m landmarkdiff serve)." ) from None elif isinstance(e, requests.HTTPError): raise LandmarkDiffAPIError( f"Server returned error {e.response.status_code}: {e.response.text[:200]}" ) from None else: raise def procedures(self) -> list[str]: """List available surgical procedures. Returns: List of procedure names. Raises: LandmarkDiffAPIError: If server is unreachable or returns an error. """ session = self._get_session() try: resp = session.get(f"{self.base_url}/procedures", timeout=self.timeout) resp.raise_for_status() return resp.json().get("procedures", []) except Exception as e: import requests if isinstance(e, requests.ConnectionError): raise LandmarkDiffAPIError( f"Cannot connect to LandmarkDiff server at {self.base_url}. " f"Make sure the server is running (python -m landmarkdiff serve)." ) from None elif isinstance(e, requests.HTTPError): raise LandmarkDiffAPIError( f"Server returned error {e.response.status_code}: {e.response.text[:200]}" ) from None else: raise def predict( self, image_path: str | Path, procedure: str = "rhinoplasty", intensity: float = 65.0, seed: int = 42, ) -> PredictionResult: """Run surgical outcome prediction. Args: image_path: Path to input face image. procedure: Surgical procedure type. intensity: Intensity of the modification (0-100). seed: Random seed for reproducibility. Returns: PredictionResult with output image and metadata. """ session = self._get_session() image_bytes = self._read_image(image_path) files = {"image": ("image.png", image_bytes, "image/png")} data = { "procedure": procedure, "intensity": str(intensity), "seed": str(seed), } resp = session.post( f"{self.base_url}/predict", files=files, data=data, timeout=self.timeout ) try: resp.raise_for_status() result = resp.json() # Decode output image output_img = self._decode_base64_image(result["output_image"]) return PredictionResult( output_image=output_img, procedure=procedure, intensity=intensity, confidence=result.get("confidence", 0.0), metrics=result.get("metrics", {}), metadata=result.get("metadata", {}), ) except Exception as e: import requests if isinstance(e, requests.ConnectionError): raise LandmarkDiffAPIError( f"Cannot connect to LandmarkDiff server at {self.base_url}. " f"Make sure the server is running (python -m landmarkdiff serve)." ) from None elif isinstance(e, requests.HTTPError): raise LandmarkDiffAPIError( f"Server returned error {e.response.status_code}: {e.response.text[:200]}" ) from None else: raise def analyze(self, image_path: str | Path) -> dict[str, Any]: """Analyze a face image without generating a prediction. Returns face landmarks, Fitzpatrick type, pose estimation, etc. Args: image_path: Path to input face image. Returns: Dict with analysis results. Raises: LandmarkDiffAPIError: If server is unreachable or returns an error. """ session = self._get_session() image_bytes = self._read_image(image_path) files = {"image": ("image.png", image_bytes, "image/png")} try: resp = session.post(f"{self.base_url}/analyze", files=files, timeout=self.timeout) resp.raise_for_status() return resp.json() except Exception as e: import requests if isinstance(e, requests.ConnectionError): raise LandmarkDiffAPIError( f"Cannot connect to LandmarkDiff server at {self.base_url}. " f"Make sure the server is running (python -m landmarkdiff serve)." ) from None elif isinstance(e, requests.HTTPError): raise LandmarkDiffAPIError( f"Server returned error {e.response.status_code}: {e.response.text[:200]}" ) from None else: raise def batch_predict( self, image_paths: list[str | Path], procedure: str = "rhinoplasty", intensity: float = 65.0, seed: int = 42, ) -> list[PredictionResult]: """Run batch prediction on multiple images. Args: image_paths: List of image file paths. procedure: Procedure to apply to all images. intensity: Intensity for all images. seed: Base random seed. Returns: List of PredictionResult objects. """ results = [] for i, path in enumerate(image_paths): try: result = self.predict( path, procedure=procedure, intensity=intensity, seed=seed + i, ) results.append(result) except Exception as e: # Create a failed result results.append( PredictionResult( output_image=np.zeros((512, 512, 3), dtype=np.uint8), procedure=procedure, intensity=intensity, metadata={"error": str(e), "path": str(path)}, ) ) return results def close(self) -> None: """Close the HTTP session.""" if self._session is not None: self._session.close() self._session = None def __enter__(self) -> LandmarkDiffClient: return self def __exit__(self, *args: Any) -> None: self.close() def __repr__(self) -> str: return f"LandmarkDiffClient(base_url='{self.base_url}')"