Spaces:
Running
Running
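"""Unified CLI for LandmarkDiff.

Usage:
    landmarkdiff infer IMAGE --procedure rhinoplasty --intensity 65
    landmarkdiff evaluate --test-dir data/test --checkpoint checkpoints/latest
    landmarkdiff train --config configs/phaseA.yaml
    landmarkdiff demo IMAGE --output demo_report.png
    landmarkdiff config --file configs/phaseA.yaml --validate
    landmarkdiff validate INPUT_IMAGE OUTPUT_IMAGE --watermark
    landmarkdiff ensemble IMAGE --n-samples 5 --strategy best_of_n
    landmarkdiff version

NOTE(review): ``train`` and ``demo`` are listed above, but ``main()`` in this
module registers no such subcommands -- confirm where they are provided.
"""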
| from __future__ import annotations | |
| import argparse | |
| import sys | |
def cmd_infer(args: argparse.Namespace) -> None:
    """Run single-image inference and write the result to disk.

    Reads ``args.image``, resizes it to 512x512, passes it to
    ``LandmarkDiffPipeline.generate`` and writes ``result["output"]`` to
    ``args.output`` (parent directories are created as needed). When
    ``args.watermark`` is set, a watermarked copy is also written next to the
    output, with ``_watermarked`` appended to the file stem.

    Args:
        args: Parsed CLI namespace. Reads ``image``, ``mode``, ``checkpoint``,
            ``displacement_model``, ``procedure``, ``intensity``, ``seed``,
            ``output`` and ``watermark``.

    Raises:
        SystemExit: With status 1 if the input image cannot be read or an
            output image cannot be written.
    """
    # Imported inside the function, so these packages are only loaded when
    # this subcommand actually runs.
    from pathlib import Path

    import cv2

    from landmarkdiff.inference import LandmarkDiffPipeline

    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(args.image)
    if image is None:
        print(f"ERROR: Cannot read image: {args.image}")
        sys.exit(1)
    image = cv2.resize(image, (512, 512))

    pipeline = LandmarkDiffPipeline(
        mode=args.mode,
        controlnet_checkpoint=args.checkpoint,
        displacement_model_path=args.displacement_model,
    )
    pipeline.load()
    result = pipeline.generate(
        image,
        procedure=args.procedure,
        intensity=args.intensity,
        seed=args.seed,
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite returns False when the file could not be written; without
    # this check a failed write would still be reported as "Output saved".
    if not cv2.imwrite(str(out_path), result["output"]):
        print(f"ERROR: Cannot write image: {out_path}")
        sys.exit(1)
    print(f"Output saved: {out_path}")

    if args.watermark:
        from landmarkdiff.safety import SafetyValidator

        validator = SafetyValidator()
        watermarked = validator.apply_watermark(result["output"])
        # Same result as Path.with_stem(stem + "_watermarked"), but with_name
        # is also available on Python versions older than 3.9.
        wm_path = out_path.with_name(f"{out_path.stem}_watermarked{out_path.suffix}")
        if not cv2.imwrite(str(wm_path), watermarked):
            print(f"ERROR: Cannot write image: {wm_path}")
            sys.exit(1)
        print(f"Watermarked: {wm_path}")
def cmd_ensemble(args: argparse.Namespace) -> None:
    """Forward the parsed ``ensemble`` options to ``ensemble_inference``.

    Args:
        args: Parsed CLI namespace. Reads ``image``, ``procedure``,
            ``intensity``, ``output``, ``n_samples``, ``strategy``, ``mode``,
            ``checkpoint``, ``displacement_model`` and ``seed``.
    """
    from landmarkdiff.ensemble import ensemble_inference

    # Keyword names expected by ensemble_inference; a few differ from the CLI
    # attribute names (image, output, checkpoint, displacement_model).
    options = {
        "image_path": args.image,
        "procedure": args.procedure,
        "intensity": args.intensity,
        "output_dir": args.output,
        "n_samples": args.n_samples,
        "strategy": args.strategy,
        "mode": args.mode,
        "controlnet_checkpoint": args.checkpoint,
        "displacement_model_path": args.displacement_model,
        "seed": args.seed,
    }
    ensemble_inference(**options)
def cmd_evaluate(args: argparse.Namespace) -> None:
    """Evaluate on a test set by running ``scripts/run_evaluation.py``.

    The script is launched in a child interpreter rather than imported, which
    avoids a circular dependency: the ``landmarkdiff`` package should not
    import from ``scripts/``.

    Args:
        args: Parsed CLI namespace. Reads ``test_dir``, ``output``,
            ``checkpoint`` and ``max_samples``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status (``check=True``).
    """
    import subprocess
    from pathlib import Path

    # scripts/ sits next to the directory that contains this file.
    package_parent = Path(__file__).resolve().parent.parent
    script_path = str(package_parent / "scripts" / "run_evaluation.py")

    command = [
        sys.executable,
        script_path,
        "--test_dir",
        args.test_dir,
        "--output",
        args.output,
    ]
    # Optional flags are forwarded only when truthy, so a checkpoint of None
    # or a max_samples of 0 adds nothing to the command line.
    if args.checkpoint:
        command.extend(["--checkpoint", args.checkpoint])
    if args.max_samples:
        command.extend(["--max_samples", str(args.max_samples)])

    subprocess.run(command, check=True)
def cmd_config(args: argparse.Namespace) -> None:
    """Print a configuration as YAML, or list its validation warnings.

    The configuration is loaded from ``args.file`` when one is given;
    otherwise a default-constructed ``ExperimentConfig`` is used.

    Args:
        args: Parsed CLI namespace. Reads ``file`` and ``validate``.
    """
    from landmarkdiff.config import ExperimentConfig, load_config, validate_config

    config = load_config(args.file) if args.file else ExperimentConfig()

    if not args.validate:
        # Show mode: these imports are only needed for the YAML dump.
        from dataclasses import asdict

        import yaml

        print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
        return

    issues = validate_config(config)
    if not issues:
        print("Configuration valid (no warnings).")
        return

    print("Validation warnings:")
    for issue in issues:
        print(f" - {issue}")
def cmd_validate(args: argparse.Namespace) -> None:
    """Check a generated image against its input with ``SafetyValidator``.

    Prints ``result.summary()`` for the validation result.

    Args:
        args: Parsed CLI namespace. Reads ``input``, ``output_image``,
            ``watermark`` and ``face_confidence``.

    Raises:
        SystemExit: With status 1 if either image cannot be read, or if the
            validation result's ``passed`` attribute is falsy.
    """
    import cv2

    from landmarkdiff.safety import SafetyValidator

    source_img = cv2.imread(args.input)
    generated_img = cv2.imread(args.output_image)
    # Both files are read before checking, so one message covers either failure.
    if any(img is None for img in (source_img, generated_img)):
        print("ERROR: Cannot read input or output image.")
        sys.exit(1)

    report = SafetyValidator(watermark_enabled=args.watermark).validate(
        input_image=source_img,
        output_image=generated_img,
        face_confidence=args.face_confidence,
    )
    print(report.summary())

    if report.passed:
        return
    sys.exit(1)
def cmd_version(args: argparse.Namespace) -> None:
    """Print the LandmarkDiff package version.

    Args:
        args: Parsed CLI namespace. Unused; accepted because ``main`` calls
            every subcommand handler as ``args.func(args)``.
    """
    from landmarkdiff import __version__ as version

    banner = f"LandmarkDiff v{version}"
    print(banner)
def main(argv: list[str] | None = None) -> None:
    """Parse the command line and run the selected subcommand.

    Args:
        argv: Argument list to parse. ``None`` is handed to
            ``ArgumentParser.parse_args`` unchanged, which then falls back to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 1, after printing the help text, when no
            subcommand was given. argparse also exits by itself on invalid
            arguments.
    """
    # Choice lists, hoisted so the option definitions below stay short.
    procedure_choices = [
        "rhinoplasty",
        "blepharoplasty",
        "rhytidectomy",
        "orthognathic",
        "brow_lift",
        "mentoplasty",
    ]
    mode_choices = ["controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps"]
    strategy_choices = ["pixel_average", "weighted_average", "best_of_n", "median"]

    root = argparse.ArgumentParser(
        prog="landmarkdiff",
        description="LandmarkDiff: Facial surgery outcome prediction via latent diffusion",
    )
    commands = root.add_subparsers(dest="command", help="Available commands")

    # infer: single-image inference
    infer = commands.add_parser("infer", help="Run single-image inference")
    infer.add_argument("image", help="Input face image path")
    infer.add_argument("--procedure", default="rhinoplasty", choices=procedure_choices)
    infer.add_argument("--intensity", type=float, default=65.0)
    infer.add_argument("--output", default="output.png")
    infer.add_argument("--mode", default="tps", choices=mode_choices)
    infer.add_argument("--checkpoint", default=None)
    infer.add_argument("--displacement-model", default=None)
    infer.add_argument("--seed", type=int, default=42)
    infer.add_argument("--watermark", action="store_true")
    infer.set_defaults(func=cmd_infer)

    # ensemble: multi-sample inference (--procedure is not restricted here)
    ensemble = commands.add_parser("ensemble", help="Run ensemble inference")
    ensemble.add_argument("image", help="Input face image path")
    ensemble.add_argument("--procedure", default="rhinoplasty")
    ensemble.add_argument("--intensity", type=float, default=65.0)
    ensemble.add_argument("--output", default="ensemble_output")
    ensemble.add_argument("--n-samples", type=int, default=5)
    ensemble.add_argument("--strategy", default="best_of_n", choices=strategy_choices)
    ensemble.add_argument("--mode", default="tps", choices=mode_choices)
    ensemble.add_argument("--checkpoint", default=None)
    ensemble.add_argument("--displacement-model", default=None)
    ensemble.add_argument("--seed", type=int, default=42)
    ensemble.set_defaults(func=cmd_ensemble)

    # evaluate: test-set evaluation
    evaluate = commands.add_parser("evaluate", help="Evaluate on test set")
    evaluate.add_argument("--test-dir", required=True)
    evaluate.add_argument("--output", default="eval_results")
    evaluate.add_argument("--checkpoint", default=None)
    evaluate.add_argument("--max-samples", type=int, default=0)
    evaluate.set_defaults(func=cmd_evaluate)

    # config: show or validate a configuration
    config = commands.add_parser("config", help="Show or validate configuration")
    config.add_argument("--file", default=None, help="YAML config file")
    config.add_argument("--validate", action="store_true")
    config.set_defaults(func=cmd_config)

    # validate: safety validation of a generated image
    validate = commands.add_parser("validate", help="Run safety validation")
    validate.add_argument("input", help="Original input image")
    validate.add_argument("output_image", help="Generated output image")
    validate.add_argument("--watermark", action="store_true")
    validate.add_argument("--face-confidence", type=float, default=1.0)
    validate.set_defaults(func=cmd_validate)

    # version
    version = commands.add_parser("version", help="Print version")
    version.set_defaults(func=cmd_version)

    parsed = root.parse_args(argv)
    # "func" is only set by a subparser's defaults, so it is absent when the
    # program is invoked without a subcommand.
    handler = getattr(parsed, "func", None)
    if handler is None:
        root.print_help()
        sys.exit(1)
    handler(parsed)
# Run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()