"""Inference entry point for MIQA (Machine-centric Image Quality Assessment).

Provides the ``MIQAInference`` wrapper class and a small command-line
interface (``main``) for scoring a single image or a directory of images.
"""

import argparse
import csv
import json
import logging
import os
import sys
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as transforms
# Image processing imports
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

# Import your existing model components
from models.MIQA_base import get_torch_model, get_timm_model
from models.RA_MIQA import RegionVisionTransformer
from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
from utils.hf_download_utils import ensure_checkpoint_from_hf

# Supported image file extensions (lower-case, with leading dot).
# Callers compare against ``Path.suffix.lower()`` so matching is
# case-insensitive ('.JPEG', '.Png', ... are all accepted).
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}


class MIQAInference:
    """
    Inference wrapper for MIQA models.

    This class handles model initialization, automatic weight downloading,
    image preprocessing, batch prediction, and result visualization.
    """

    def __init__(self,
                 task: str,
                 model_name: str = 'ra_miqa',
                 metric_type: str = 'composite',
                 device: Optional[str] = None):
        """
        Initialize the MIQA inference system.

        Args:
            task: Task type - 'cls' (classification), 'det' (detection), or 'ins' (instance)
            model_name: Model architecture to use (default: RA_MIQA for best performance)
            metric_type: Training objective - 'composite', 'consistency', or 'accuracy'
            device: Device to run inference on ('cuda' or 'cpu'). Auto-detects if None.

        Raises:
            ValueError: If the task / model / metric combination is not registered.
            RuntimeError: If the model weights cannot be obtained.
        """
        self.task = task.lower()
        self.model_name = model_name
        self.metric_type = metric_type

        # Setup logging with clean formatting
        self.logger = self._setup_logger()

        # Determine computation device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.logger.info("🚀 Initializing MIQA Inference System")
        self.logger.info(f" Task: {self.task.upper()}")
        self.logger.info(f" Model: {self.model_name}")
        self.logger.info(f" Metric Type: {self.metric_type}")
        self.logger.info(f" Device: {self.device}")

        # Validate configuration before touching disk or network
        self._validate_config()

        # Initialize model
        self.model = self._load_model()

        # Setup image preprocessing pipeline
        self.transforms1, self.transforms2 = self._get_transforms()

        self.logger.info("✅ System ready for inference\n")

    def _setup_logger(self) -> logging.Logger:
        """Return the 'MIQA_Inference' logger, attaching a console handler once.

        The logger writes bare messages to stdout and does not propagate to
        the root logger.
        """
        logger = logging.getLogger('MIQA_Inference')
        logger.setLevel(logging.INFO)

        # Check this logger's OWN handlers. ``hasHandlers()`` also reports
        # handlers of ancestor loggers, which would skip our setup whenever
        # the application has configured the root logger.
        if logger.handlers:
            return logger

        logger.propagate = False

        # Console handler with clean formatting
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)

        return logger

    def _validate_config(self) -> None:
        """Validate that the requested configuration is supported.

        Raises:
            ValueError: If metric_type, task, or model_name is not registered
                in ``MODEL_FILENAMES``.
        """
        if self.metric_type not in ['composite', 'consistency', 'accuracy']:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. "
                f"Supported: ['composite', 'consistency', 'accuracy']"
            )

        tasks_for_metric = MODEL_FILENAMES[self.metric_type]
        if self.task not in tasks_for_metric:
            raise ValueError(
                f"Invalid task '{self.task}'. "
                f"Supported tasks: {list(tasks_for_metric.keys())}"
            )

        models_for_task = tasks_for_metric[self.task]
        if self.model_name not in models_for_task:
            available = list(models_for_task.keys())
            raise ValueError(
                f"Model '{self.model_name}' not available for task '{self.task}'. "
                f"Available models: {available}"
            )

    def _get_checkpoint_path(self) -> str:
        """Return the local checkpoint path, creating its directory if needed."""
        base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
        base_dir.mkdir(parents=True, exist_ok=True)

        filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
        return str(base_dir / filename)

    def _download_weights(self, checkpoint_path: str) -> bool:
        """
        Download model weights if not present locally.

        Args:
            checkpoint_path: Local path where the checkpoint is expected.

        Returns:
            True if weights are available (already existed or successfully downloaded)
        """
        if os.path.exists(checkpoint_path):
            self.logger.info("✓ Found cached model weights")
            return True

        target = Path(checkpoint_path)
        self.logger.info(
            f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
            f"file={target.name}, rev={HF_REVISION}"
        )

        try:
            ensure_checkpoint_from_hf(
                repo_id=HF_REPO_ID,
                filename=target.name,
                local_dir=str(target.parent),
                revision=HF_REVISION,
            )
        except Exception as e:
            # Deliberately broad: any download failure is reported to the
            # caller through the boolean return value.
            self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
            return False

        self.logger.info("✓ Successfully downloaded model weights")
        return True

    def _create_model(self) -> torch.nn.Module:
        """Create the (randomly initialised) model architecture."""
        if self.model_name == 'ra_miqa':
            self.logger.info("Building Region-Aware Vision Transformer...")
            return RegionVisionTransformer(
                base_model_name='vit_small_patch16_224',
                pretrained=False,  # We'll load our trained weights
                mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
                checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
            )

        self.logger.info(f"Building {self.model_name} from PyTorch...")
        try:
            return get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
        except Exception:
            # Best-effort fallback: names unknown to the torch builder are
            # retried through timm.
            self.logger.info(f"Building {self.model_name} from timm library...")
            return get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)

    def _load_model(self) -> torch.nn.Module:
        """Load model with pre-trained weights.

        Returns:
            The model on ``self.device`` in evaluation mode.

        Raises:
            RuntimeError: If the checkpoint is neither cached nor downloadable.
        """
        checkpoint_path = self._get_checkpoint_path()

        # Ensure weights are available
        if not self._download_weights(checkpoint_path):
            raise RuntimeError("Cannot proceed without model weights")

        # Create model architecture
        self.logger.info("🔧 Loading model...")
        model = self._create_model()

        # Load weights.
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)

        # Remove a leading 'module.' prefix if present (from DataParallel
        # training). Only the prefix is stripped, so a 'module.' substring
        # occurring later in a key is left intact.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )

        model.load_state_dict(new_state_dict, strict=True)
        model = model.to(self.device)
        model.eval()  # Set to evaluation mode

        self.logger.info("✓ Model loaded successfully")
        return model

    def _get_transforms(self) -> Tuple[transforms.Compose, Optional[transforms.Compose]]:
        """
        Return preprocessing transforms based on model type.

        Returns:
            A ``(primary, secondary)`` pair. ``secondary`` is None for
            single-input backbones and a second pipeline for 'ra_miqa'.
        """
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)
        SIMPLE_MEAN = (0.5, 0.5, 0.5)
        SIMPLE_STD = (0.5, 0.5, 0.5)

        def build(crop, mean, std) -> transforms.Compose:
            """Resize to 288, center-crop to *crop*, tensorise, normalise."""
            return transforms.Compose([
                transforms.Resize(288),
                transforms.CenterCrop(crop),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

        name = self.model_name

        # 1. CNNs (ResNet / EfficientNet)
        if any(k in name for k in ['resnet', 'efficientnet']):
            return build(224, IMAGENET_MEAN, IMAGENET_STD), None

        # 2. ViT
        if 'vit' in name:
            return build(224, SIMPLE_MEAN, SIMPLE_STD), None

        # 3. ra_miqa: two inputs with different crop size and normalisation
        if 'ra_miqa' in name:
            return (
                build(224, SIMPLE_MEAN, SIMPLE_STD),
                build((288, 288), IMAGENET_MEAN, IMAGENET_STD),
            )

        # Fallback
        self.logger.warning(
            f"[Warning] Unknown model type '{self.model_name}', using ImageNet normalization."
        )
        return build(224, IMAGENET_MEAN, IMAGENET_STD), None

    def _prepare_image(self, image_path: str):
        """
        Load and preprocess a single image.

        Args:
            image_path: Path to the image file.

        Returns:
            ``(img1, img2, original_img)`` where ``img1`` / ``img2`` are
            batched tensors (``img2`` is None when no second transform is
            configured) and ``original_img`` is the RGB PIL image.
        """
        # The context manager closes the underlying file handle;
        # ``convert`` returns an independent in-memory copy.
        with Image.open(image_path) as src:
            img = src.convert('RGB')

        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0) if self.transforms2 else None
        return img1, img2, img

    @torch.no_grad()
    def predict_single(self, image_path: str) -> Dict:
        """
        Prediction interface for different backbones.

        Args:
            image_path: Path to the image file.

        Returns:
            Dict with keys 'image_path', 'image_name', 'quality_score' and
            'original_image' (the RGB PIL image).
        """
        img1, img2, original_img = self._prepare_image(image_path)

        img1 = img1.to(self.device)
        if img2 is None:
            output = self.model(img1)
        else:
            output = self.model(img1, img2.to(self.device))

        score = output.item() if torch.is_tensor(output) else float(output)

        return {
            'image_path': image_path,
            'image_name': Path(image_path).name,
            'quality_score': score,
            'original_image': original_img
        }

    @torch.no_grad()
    def predict_batch(self, image_paths: List[str], show_progress: bool = True) -> List[Dict]:
        """
        Run inference on multiple images with progress tracking.

        A failure on one image is logged and recorded as a result with
        ``quality_score`` None and an 'error' message; processing continues.

        Args:
            image_paths: List of paths to image files
            show_progress: Whether to display progress bar

        Returns:
            List of prediction results for each image
        """
        results = []

        # Create progress bar if requested
        iterator = tqdm(image_paths, desc="Processing images",
                        disable=not show_progress, ncols=80)

        for img_path in iterator:
            try:
                results.append(self.predict_single(img_path))
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to process {img_path}: {str(e)}")
                results.append({
                    'image_path': img_path,
                    'image_name': Path(img_path).name,
                    'quality_score': None,
                    'error': str(e)
                })

        return results

    def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]:
        """
        Main prediction interface - handles both single images and directories.

        Args:
            input_path: Path to an image file or directory containing images
            show_progress: Whether to show progress bar for batch processing

        Returns:
            List of prediction results

        Raises:
            ValueError: If the path does not exist, has an unsupported
                extension, or is a directory without supported images.
        """
        input_path = Path(input_path)

        # Handle single file
        if input_path.is_file():
            if input_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
                raise ValueError(
                    f"Unsupported file extension: {input_path.suffix}. "
                    f"Supported: {SUPPORTED_EXTENSIONS}"
                )
            return [self.predict_single(str(input_path))]

        # Handle directory
        if input_path.is_dir():
            # Match extensions case-insensitively, exactly like the
            # single-file branch; one pass means no duplicate entries.
            image_paths = sorted(
                str(p) for p in input_path.iterdir()
                if p.is_file() and p.suffix.lower() in SUPPORTED_EXTENSIONS
            )

            if not image_paths:
                raise ValueError(f"No supported images found in {input_path}")

            self.logger.info(f"📁 Found {len(image_paths)} images in directory")
            return self.predict_batch(image_paths, show_progress)

        raise ValueError(f"Input path does not exist: {input_path}")

    @staticmethod
    def _score_to_color(norm_score: float) -> Tuple[int, int, int]:
        """Map a score in [0, 1] to RGB: red (0) -> yellow (0.5) -> green (1)."""
        if norm_score < 0.5:
            # Red to yellow
            return 255, int(255 * (norm_score * 2)), 0
        # Yellow to green
        return int(255 * (2 - norm_score * 2)), 255, 0

    def visualize_results(self,
                          results: List[Dict],
                          output_dir: str = 'inference_results',
                          score_range: Tuple[float, float] = (0, 100)) -> None:
        """
        Create annotated visualizations of predictions.

        This generates images with quality scores overlaid in a color-coded box.
        Low scores appear in red, high scores in green.

        Args:
            results: List of prediction results
            output_dir: Directory to save visualizations
            score_range: Expected range of quality scores for color normalization

        Raises:
            ValueError: If both bounds of ``score_range`` are equal.
        """
        low, high = score_range[0], score_range[1]
        span = high - low
        if span == 0:
            raise ValueError(
                f"score_range bounds must differ, got {score_range}"
            )

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self.logger.info("\n🎨 Creating visualizations...")

        # Load the font once (try a nice font, fall back to the default).
        try:
            font = ImageFont.truetype("arial.ttf", 30)
        except OSError:
            font = ImageFont.load_default()

        box_width = 250
        box_height = 50
        margin = 10
        # Position in top-left corner
        box_coords = [margin, margin, margin + box_width, margin + box_height]

        for result in tqdm(results, desc="Generating visualizations", ncols=80):
            score = result.get('quality_score')
            if score is None:
                continue  # Skip failed predictions

            img = result['original_image'].copy()
            draw = ImageDraw.Draw(img)

            score_text = f"Quality: {score:.4f}"

            # Normalize score to [0, 1] for color interpolation, then clamp
            norm_score = max(0, min(1, (score - low) / span))

            # Draw colored box with score
            draw.rectangle(box_coords, fill=self._score_to_color(norm_score))

            # Calculate text position to center it in the box
            bbox = draw.textbbox((0, 0), score_text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
            text_x = margin + (box_width - text_width) // 2
            text_y = margin + (box_height - text_height) // 2

            draw.text((text_x, text_y), score_text, fill='black', font=font)

            # Save annotated image
            output_file = output_path / f"miqa_{self.model_name}_{result['image_name']}"
            img.save(output_file)

        self.logger.info(f"✓ Visualizations saved to: {output_dir}/")

    def save_results(self,
                     results: List[Dict],
                     output_path: str = 'predictions.json',
                     format: str = 'json') -> None:
        """
        Save prediction results to file.

        Args:
            results: List of prediction results
            output_path: Path to save results
            format: Output format - 'json' or 'csv'

        Raises:
            ValueError: If ``format`` is neither 'json' nor 'csv'.
        """
        if format not in ('json', 'csv'):
            raise ValueError(
                f"Unsupported format '{format}'. Supported: ['json', 'csv']"
            )

        # Remove PIL image objects before saving
        clean_results = [
            {k: v for k, v in r.items() if k != 'original_image'}
            for r in results
        ]

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == 'json':
            payload = {
                'metadata': {
                    'task': self.task,
                    'model': self.model_name,
                    'metric_type': self.metric_type,
                    'timestamp': datetime.now().isoformat(),
                    'num_images': len(clean_results)
                },
                'predictions': clean_results
            }
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2)
        else:
            # Successful and failed rows carry different keys ('error' only
            # on failures), so the header is the ordered union of all keys;
            # missing cells are left empty by DictWriter.
            fieldnames = list(dict.fromkeys(
                key for row in clean_results for key in row
            ))
            with open(output_path, 'w', newline='', encoding='utf-8') as f:
                if clean_results:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(clean_results)

        self.logger.info(f"💾 Results saved to: {output_path}")

    def print_summary(self, results: List[Dict]) -> None:
        """Print a formatted summary of prediction results."""
        valid_results = [r for r in results if r.get('quality_score') is not None]
        failed_results = [r for r in results if r.get('quality_score') is None]

        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREDICTION SUMMARY")
        self.logger.info("=" * 80)

        if valid_results:
            scores = [r['quality_score'] for r in valid_results]

            self.logger.info(f"✓ Successfully processed: {len(valid_results)} images")
            self.logger.info(f" Average quality score: {np.mean(scores):.4f}")
            self.logger.info(f" Score range: [{np.min(scores):.4f}, {np.max(scores):.4f}]")
            self.logger.info(f" Standard deviation: {np.std(scores):.4f}")

            # Show top and bottom quality images
            sorted_results = sorted(valid_results,
                                    key=lambda x: x['quality_score'],
                                    reverse=True)

            self.logger.info("\n🏆 Top 3 quality images:")
            for i, r in enumerate(sorted_results[:3], 1):
                self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

            if len(sorted_results) > 3:
                self.logger.info("\n⚠️ Bottom 3 quality images:")
                for i, r in enumerate(sorted_results[-3:], 1):
                    self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

        if failed_results:
            self.logger.info(f"\n❌ Failed to process: {len(failed_results)} images")

        self.logger.info("=" * 80 + "\n")


def main():
    """Command-line interface for MIQA inference.

    Exits with status 1 and prints the error to stderr on any failure.
    """
    parser = argparse.ArgumentParser(
        description='MIQA: Machine-centric Image Quality Assessment',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Predict quality of a single image
  python img_inference.py --input image.jpg --task cls --model ra_miqa

  # Process all images in a directory
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task det --model ra_miqa

  # Save results and create visualizations
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task ins --save-results --visualize
        """
    )

    parser.add_argument('--input', type=str, required=True,
                        help='Path to input image or directory containing images')
    parser.add_argument('--task', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Task type: cls (classification), det (detection), ins (instance)')
    parser.add_argument('--model', type=str, default='ra_miqa',
                        choices=['ra_miqa'],
                        help='Model architecture (default: ra_miqa; Hub weights are RA-MIQA only)')
    parser.add_argument('--metric-type', type=str, default='composite',
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric type (default: composite)')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on (auto-detect if not specified)')
    parser.add_argument('--save-results', action='store_true',
                        help='Save prediction results to file')
    parser.add_argument('--output-file', type=str, default='predictions.json',
                        help='Output file path for results (default: predictions.json)')
    parser.add_argument('--output-format', type=str, default='json',
                        choices=['json', 'csv'],
                        help='Output file format (default: json)')
    parser.add_argument('--visualize', action='store_true',
                        help='Create annotated visualizations of predictions')
    parser.add_argument('--save-dir', type=str, default='inference_results',
                        help='Directory for save (default: inference_results)')
    parser.add_argument('--no-progress', action='store_true',
                        help='Disable progress bar')

    args = parser.parse_args()

    try:
        # Initialize inference system
        miqa = MIQAInference(
            task=args.task,
            model_name=args.model,
            metric_type=args.metric_type,
            device=args.device
        )

        # Run predictions
        results = miqa.predict(args.input, show_progress=not args.no_progress)

        # Print summary
        miqa.print_summary(results)

        save_dir = Path(args.save_dir) / 'image' / args.task / args.metric_type

        # Prefix only the file NAME with the model tag so that an
        # --output-file containing directories keeps a valid path.
        requested = Path(args.output_file)
        output_file = save_dir / requested.parent / f'miqa_{args.model}_{requested.name}'

        # Save results if requested
        if args.save_results:
            miqa.save_results(results, str(output_file), args.output_format)

        # Create visualizations if requested
        if args.visualize:
            miqa.visualize_results(results, str(save_dir))

    except Exception as e:
        # Top-level boundary: report and exit non-zero.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()