# Example images: depth-41888, 15254, 16228, 24144, 37777, 22192
# Ablate image: 87038
import argparse
import json
import os

import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image
from tqdm import tqdm

from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition
from omini.rotation import RotationConfig, RotationTuner


def evaluate(pipe,
             condition_type: str,  # e.g., "canny"
             caption_file: str,
             image_dir: str,
             save_root_dir: str,
             num_images: int = 5000,
             start_index: int = 0):
    """
    Evaluate the model on a subset of the COCO dataset.

    Args:
        pipe: The Flux pipeline to use for generation
        condition_type: Type of condition image to generate (e.g., "canny", "depth")
        caption_file: Path to the COCO captions JSON file
        image_dir: Directory containing COCO images
        save_root_dir: Root directory for the generated/resized/condition outputs
        num_images: Number of images to evaluate on
        start_index: Index of the first image to evaluate (for sharded/resumable runs)
    """
    os.makedirs(os.path.join(save_root_dir, "generated"), exist_ok=True)
    os.makedirs(os.path.join(save_root_dir, "resized"), exist_ok=True)
    os.makedirs(os.path.join(save_root_dir, condition_type), exist_ok=True)

    # Load data
    with open(caption_file, "r") as f:
        coco_data = json.load(f)

    # Build a mapping: image_id -> (filename, captions)
    id_to_filename = {img["id"]: img["file_name"] for img in coco_data["images"]}
    captions_by_image = {}
    for ann in coco_data["annotations"]:
        img_id = ann["image_id"]
        captions_by_image.setdefault(img_id, []).append(ann["caption"])

    # Restrict the evaluation pool to the first 5000 images
    image_ids = list(id_to_filename.keys())[:5000]

    # Collect data
    captions_subset = [
        {
            "image_id": img_id,
            "file_name": id_to_filename[img_id],
            "captions": captions_by_image.get(img_id, []),
        }
        for img_id in image_ids
    ]

    for item in tqdm(captions_subset[start_index:start_index + num_images]):
        image_id = item["image_id"]
        image_path = os.path.join(image_dir, item["file_name"])
        image = Image.open(image_path).convert("RGB")

        # Center-crop to a square, then resize to 512x512
        w, h = image.size
        min_dim = min(w, h)
        image = image.crop(
            ((w - min_dim) // 2,
             (h - min_dim) // 2,
             (w + min_dim) // 2,
             (h + min_dim) // 2)
        ).resize((512, 512))

        condition_image = convert_to_condition(condition_type, image)
        condition = Condition(condition_image, condition_type)
        prompt = item["captions"][0] if item["captions"] else "No caption available."

        seed_everything(42)

        # Generate image
        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition],
        ).images[0]

        result_img.save(os.path.join(save_root_dir, "generated", f"{image_id}.jpg"))
        image.save(os.path.join(save_root_dir, "resized", f"{image_id}.jpg"))
        condition.condition.save(os.path.join(save_root_dir, condition_type, f"{image_id}.jpg"))
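
# For reference, the captions file parsed in `evaluate` is assumed to follow the
# standard COCO annotation schema, abridged here (field values are illustrative):
#
#   {
#     "images":      [{"id": 397133, "file_name": "000000397133.jpg", ...}, ...],
#     "annotations": [{"image_id": 397133, "caption": "A man riding a bike.", ...}, ...]
#   }
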
def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
    """
    Load rotation adapter weights.

    Args:
        transformer: The transformer module to load the adapter into
        path: Directory containing the saved adapter weights
        adapter_name: Name of the adapter to load
        strict: Whether to strictly match all keys
    """
    from safetensors.torch import load_file
    import yaml

    device = transformer.device
    print(f"device for loading: {device}")

    # Try to load safetensors first, then fall back to .pth
    safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
    pth_path = os.path.join(path, f"{adapter_name}.pth")
    if os.path.exists(safetensors_path):
        state_dict = load_file(safetensors_path)
        print(f"Loaded rotation adapter from {safetensors_path}")
    elif os.path.exists(pth_path):
        state_dict = torch.load(pth_path, map_location=device)
        print(f"Loaded rotation adapter from {pth_path}")
    else:
        raise FileNotFoundError(
            f"No adapter weights found for '{adapter_name}' in {path}\n"
            f"Looking for: {safetensors_path} or {pth_path}"
        )

    # Get the device and dtype of the transformer
    transformer_device = next(transformer.parameters()).device
    transformer_dtype = next(transformer.parameters()).dtype

    # Re-insert the adapter name into the keys (the reverse of what save did)
    # and move each tensor to the transformer's device/dtype.
    state_dict_with_adapter = {}
    for k, v in state_dict.items():
        new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
        # Integer/boolean tensors (e.g., indices, masks) keep their dtype and
        # only move device; floating-point tensors are also cast to the
        # transformer's dtype.
        if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
            state_dict_with_adapter[new_key] = v.to(device=transformer_device)
        else:
            state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)

    # Load into the model
    missing, unexpected = transformer.load_state_dict(
        state_dict_with_adapter, strict=strict
    )
    if missing:
        print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Load config if available
    config_path = os.path.join(path, f"{adapter_name}_config.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        print(f"Loaded config: {config}")

    total_params = sum(p.numel() for p in state_dict.values())
    print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
    return state_dict
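
# `load_rotation` is not invoked in the baseline run below; a minimal sketch of
# how it could be applied in place of the fused LoRA (the checkpoint path is
# hypothetical):
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16
#   ).to("cuda")
#   load_rotation(pipe.transformer, "./ckpt/rotation_canny", adapter_name="canny")
#
# The directory is expected to contain "<adapter_name>.safetensors" (or ".pth")
# and, optionally, "<adapter_name>_config.yaml".
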
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset")
    parser.add_argument("--start_index", type=int, default=0,
                        help="Starting index for evaluation")
    parser.add_argument("--num_images", type=int, default=500,
                        help="Number of images to evaluate")
    parser.add_argument("--condition_type", type=str, default="deblurring",
                        help="Type of condition (e.g., 'deblurring', 'canny', 'depth')")
    args = parser.parse_args()

    START_INDEX = args.start_index
    NUM_IMAGES = args.num_images

    # Paths to the COCO captions file and images (change if needed)
    CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json"
    IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/"

    CONDITION_TYPE = args.condition_type
    SAVE_ROOT_DIR = f"./coco_baseline/results_{CONDITION_TYPE}_1000/"

    # Load the Flux pipeline (replace with your model path if needed)
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16
    )

    # For OminiControl: fuse the condition-specific LoRA into the base weights
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name=f"experimental/{CONDITION_TYPE}.safetensors",
        adapter_name=CONDITION_TYPE,
    )
    pipe.fuse_lora(lora_scale=1.0)
    pipe.unload_lora_weights()
    # pipe.set_adapters([CONDITION_TYPE])

    pipe = pipe.to("cuda")

    # Evaluate on COCO
    evaluate(
        pipe,
        condition_type=CONDITION_TYPE,
        caption_file=CAPTION_FILE,
        image_dir=IMAGE_DIR,
        save_root_dir=SAVE_ROOT_DIR,
        num_images=NUM_IMAGES,
        start_index=START_INDEX,
    )
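
# Typical invocation (the script name "evaluate_coco.py" is a placeholder for
# whatever this file is saved as; paths above are machine-specific):
#
#   python evaluate_coco.py --condition_type canny --start_index 0 --num_images 500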