| import os |
| import sys |
| import torch |
| import argparse |
| import logging |
| from pathlib import Path |
| from typing import List, Dict, Optional, Tuple |
| from collections import OrderedDict |
| import json |
| from datetime import datetime |
| from tqdm import tqdm |
| import numpy as np |
|
|
| |
| from PIL import Image, ImageDraw, ImageFont |
| import torchvision.transforms as transforms |
|
|
| |
| from models.MIQA_base import get_torch_model, get_timm_model |
| from models.RA_MIQA import RegionVisionTransformer |
| from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES |
| from utils.hf_download_utils import ensure_checkpoint_from_hf |
|
|
| |
# Image file suffixes accepted by the inference entry points. Every entry
# carries its leading dot so it can be compared against ``Path.suffix`` and
# used as a ``*<ext>`` glob pattern. The upper-case '.JPEG' variant is listed
# explicitly because case-sensitive globbing would otherwise skip such files.
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.JPEG', '.png', '.bmp', '.tiff', '.tif'}
|
|
|
|
class MIQAInference:
    """
    Inference wrapper for MIQA models.

    This class handles model initialization, automatic weight downloading,
    image preprocessing, batch prediction, and result visualization.
    """

    def __init__(self, task: str, model_name: str = 'ra_miqa',
                 metric_type: str = 'composite', device: Optional[str] = None):
        """
        Initialize the MIQA inference system.

        Args:
            task: Task type - 'cls' (classification), 'det' (detection), or 'ins' (instance).
                Stored lower-cased.
            model_name: Model architecture to use (default: 'ra_miqa').
            metric_type: Training objective - 'composite', 'consistency', or 'accuracy'.
            device: Device to run inference on ('cuda' or 'cpu'). Auto-detects if None.

        Raises:
            ValueError: If the task / model / metric combination is not registered.
            RuntimeError: If the model weights cannot be obtained.
        """
        self.task = task.lower()
        self.model_name = model_name
        self.metric_type = metric_type

        self.logger = self._setup_logger()

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.logger.info("🚀 Initializing MIQA Inference System")
        self.logger.info(f" Task: {self.task.upper()}")
        self.logger.info(f" Model: {self.model_name}")
        self.logger.info(f" Metric Type: {self.metric_type}")
        self.logger.info(f" Device: {self.device}")

        # Fail fast on an unsupported configuration before touching the network.
        self._validate_config()

        self.model = self._load_model()

        # transforms2 is None for single-input models.
        self.transforms1, self.transforms2 = self._get_transforms()

        self.logger.info("✅ System ready for inference\n")

    def _setup_logger(self) -> logging.Logger:
        """Return the 'MIQA_Inference' logger, attaching a stdout handler on first use."""
        logger = logging.getLogger('MIQA_Inference')
        logger.setLevel(logging.INFO)

        # hasHandlers() also reports handlers on ancestor loggers, so a logger
        # that is already reachable by some handler is returned untouched; this
        # also prevents stacking duplicate handlers across instances.
        if logger.hasHandlers():
            return logger

        logger.propagate = False

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)

        return logger

    def _validate_config(self) -> None:
        """
        Validate that the requested configuration is supported.

        Raises:
            ValueError: If the metric type, task, or model name is not
                registered in MODEL_FILENAMES.
        """
        if self.metric_type not in ['composite', 'consistency', 'accuracy']:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. "
                f"Supported: ['composite', 'consistency', 'accuracy']"
            )

        if self.task not in MODEL_FILENAMES[self.metric_type]:
            raise ValueError(
                f"Invalid task '{self.task}'. "
                f"Supported tasks: {list(MODEL_FILENAMES[self.metric_type].keys())}"
            )

        if self.model_name not in MODEL_FILENAMES[self.metric_type][self.task]:
            available = list(MODEL_FILENAMES[self.metric_type][self.task].keys())
            raise ValueError(
                f"Model '{self.model_name}' not available for task '{self.task}'. "
                f"Available models: {available}"
            )

    def _get_checkpoint_path(self) -> str:
        """Return the local checkpoint path, creating its parent directory if needed."""
        base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
        base_dir.mkdir(parents=True, exist_ok=True)

        filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
        return str(base_dir / filename)

    def _download_weights(self, checkpoint_path: str) -> bool:
        """
        Download model weights if not present locally.

        Args:
            checkpoint_path: Expected local path of the checkpoint file.

        Returns:
            True if weights are available (already existed or successfully
            downloaded), False if the download raised.
        """
        if os.path.exists(checkpoint_path):
            self.logger.info("✓ Found cached model weights")
            return True

        target = Path(checkpoint_path)
        self.logger.info(
            f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
            f"file={target.name}, rev={HF_REVISION}"
        )
        try:
            ensure_checkpoint_from_hf(
                repo_id=HF_REPO_ID,
                filename=target.name,
                local_dir=str(target.parent),
                revision=HF_REVISION,
            )
        except Exception as e:
            # Reported as a boolean so the caller decides how to fail.
            self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
            return False

        self.logger.info("✓ Successfully downloaded model weights")
        return True

    def _create_model(self) -> torch.nn.Module:
        """Create the (un-initialized) model architecture selected by ``model_name``."""
        if self.model_name == 'ra_miqa':
            self.logger.info("Building Region-Aware Vision Transformer...")
            return RegionVisionTransformer(
                base_model_name='vit_small_patch16_224',
                pretrained=False,
                mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
                checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
            )

        # Try the torch-based builder first; any failure falls back to timm.
        try:
            self.logger.info(f"Building {self.model_name} from PyTorch...")
            return get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
        except Exception:
            self.logger.info(f"Building {self.model_name} from timm library...")
            return get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)

    @staticmethod
    def _strip_module_prefix(state_dict, prefix: str = 'module.') -> OrderedDict:
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key."""
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            # Slice instead of str.replace so only the leading prefix is
            # dropped, never an occurrence in the middle of a key.
            cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
        return cleaned

    def _load_model(self) -> torch.nn.Module:
        """
        Load model with pre-trained weights.

        Returns:
            The model moved to ``self.device`` and switched to eval mode.

        Raises:
            RuntimeError: If the checkpoint is neither cached nor downloadable.
        """
        checkpoint_path = self._get_checkpoint_path()

        if not self._download_weights(checkpoint_path):
            raise RuntimeError("Cannot proceed without model weights")

        self.logger.info("🔧 Loading model...")
        model = self._create_model()

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded; consider weights_only=True once the
        # checkpoint contents are confirmed to be plain tensors.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Accept both {'state_dict': ...} wrappers and bare state dicts.
        state_dict = checkpoint.get('state_dict', checkpoint)

        model.load_state_dict(self._strip_module_prefix(state_dict), strict=True)
        model = model.to(self.device)
        model.eval()

        self.logger.info("✓ Model loaded successfully")
        return model

    def _get_transforms(self) -> Tuple["transforms.Compose", Optional["transforms.Compose"]]:
        """
        Return preprocessing transforms based on model type.

        Returns:
            A ``(primary, secondary)`` pair. ``secondary`` is None unless the
            model name contains 'ra_miqa', in which case it is a 288x288
            ImageNet-normalized pipeline.
        """
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)
        SIMPLE_MEAN = (0.5, 0.5, 0.5)
        SIMPLE_STD = (0.5, 0.5, 0.5)

        def build(crop_size, mean, std):
            # Every pipeline shares resize -> center-crop -> tensor -> normalize.
            return transforms.Compose([
                transforms.Resize(288),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])

        if any(k in self.model_name for k in ['resnet', 'efficientnet']):
            return build(224, IMAGENET_MEAN, IMAGENET_STD), None

        if 'vit' in self.model_name:
            return build(224, SIMPLE_MEAN, SIMPLE_STD), None

        if 'ra_miqa' in self.model_name:
            return (
                build(224, SIMPLE_MEAN, SIMPLE_STD),
                build((288, 288), IMAGENET_MEAN, IMAGENET_STD),
            )

        # Routed through the class logger like every other message.
        self.logger.warning(
            f"[Warning] Unknown model type '{self.model_name}', using ImageNet normalization."
        )
        return build(224, IMAGENET_MEAN, IMAGENET_STD), None

    def _prepare_image(self, image_path: str):
        """
        Load and preprocess a single image.

        Returns:
            ``(img1, img2, original_img)``: the batched primary tensor, the
            batched secondary tensor (None when there is no second
            transform), and the RGB PIL image.
        """
        # convert() returns a fully loaded copy, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as opened:
            img = opened.convert('RGB')

        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0) if self.transforms2 else None
        return img1, img2, img

    @torch.no_grad()
    def predict_single(self, image_path: str) -> Dict:
        """
        Run the model on one image.

        Returns:
            Dict with 'image_path', 'image_name', 'quality_score' (float) and
            'original_image' (the RGB PIL image).
        """
        img1, img2, original_img = self._prepare_image(image_path)

        img1 = img1.to(self.device)
        if img2 is None:
            output = self.model(img1)
        else:
            output = self.model(img1, img2.to(self.device))

        score = output.item() if torch.is_tensor(output) else float(output)

        return {
            'image_path': image_path,
            'image_name': Path(image_path).name,
            'quality_score': score,
            'original_image': original_img
        }

    @torch.no_grad()
    def predict_batch(self, image_paths: List[str],
                      show_progress: bool = True) -> List[Dict]:
        """
        Run inference on multiple images with progress tracking.

        Args:
            image_paths: List of paths to image files
            show_progress: Whether to display progress bar

        Returns:
            One result per input path, in order. A failed image yields an
            entry with ``quality_score`` None and an ``error`` message instead
            of aborting the batch.
        """
        results = []

        iterator = tqdm(image_paths, desc="Processing images",
                        disable=not show_progress, ncols=80)

        for img_path in iterator:
            try:
                results.append(self.predict_single(img_path))
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to process {img_path}: {str(e)}")
                results.append({
                    'image_path': img_path,
                    'image_name': Path(img_path).name,
                    'quality_score': None,
                    'error': str(e)
                })

        return results

    def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]:
        """
        Main prediction interface - handles both single images and directories.

        Args:
            input_path: Path to an image file or directory containing images
            show_progress: Whether to show progress bar for batch processing

        Returns:
            List of prediction results

        Raises:
            ValueError: If the file extension is unsupported, the directory
                holds no supported images, or the path does not exist.
        """
        source = Path(input_path)
        # Compare suffixes case-insensitively so '.JPG' / '.PNG' are accepted
        # on every platform.
        allowed = {ext.lower() for ext in SUPPORTED_EXTENSIONS}

        if source.is_file():
            if source.suffix.lower() not in allowed:
                raise ValueError(
                    f"Unsupported file extension: {source.suffix}. "
                    f"Supported: {SUPPORTED_EXTENSIONS}"
                )
            return [self.predict_single(str(source))]

        if source.is_dir():
            # A single pass over the directory lists each file exactly once,
            # unlike one glob per extension.
            image_paths = sorted(
                str(entry) for entry in source.iterdir()
                if entry.is_file() and entry.suffix.lower() in allowed
            )

            if not image_paths:
                raise ValueError(f"No supported images found in {source}")

            self.logger.info(f"📁 Found {len(image_paths)} images in directory")
            return self.predict_batch(image_paths, show_progress)

        raise ValueError(f"Input path does not exist: {source}")

    @staticmethod
    def _score_to_color(score: float, score_range: Tuple[float, float]) -> Tuple[int, int, int]:
        """Map ``score`` to an RGB color: red at the low end, yellow mid-range, green at the top."""
        low, high = score_range
        # Normalize, then clamp so out-of-range scores saturate.
        norm = max(0, min(1, (score - low) / (high - low)))
        if norm < 0.5:
            return 255, int(255 * (norm * 2)), 0
        return int(255 * (2 - norm * 2)), 255, 0

    def visualize_results(self, results: List[Dict], output_dir: str = 'inference_results',
                          score_range: Tuple[float, float] = (0, 100)) -> None:
        """
        Create annotated visualizations of predictions.

        This generates images with quality scores overlaid in a color-coded box.
        Low scores appear in red, high scores in green.

        Args:
            results: List of prediction results
            output_dir: Directory to save visualizations
            score_range: Expected range of quality scores for color normalization

        Raises:
            ValueError: If both bounds of ``score_range`` are equal.
        """
        if score_range[0] == score_range[1]:
            raise ValueError(f"score_range must span a non-zero interval, got {score_range}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self.logger.info("\n🎨 Creating visualizations...")

        # The font does not depend on the image, so resolve it once.
        try:
            font = ImageFont.truetype("arial.ttf", 30)
        except OSError:
            # The TrueType file could not be opened; use PIL's built-in font.
            font = ImageFont.load_default()

        box_width = 250
        box_height = 50
        margin = 10
        box_coords = [margin, margin, margin + box_width, margin + box_height]

        for result in tqdm(results, desc="Generating visualizations", ncols=80):
            score = result.get('quality_score')
            if score is None:
                continue

            source = result.get('original_image')
            if source is None:
                # Results that no longer carry the PIL image are re-read from disk.
                with Image.open(result['image_path']) as opened:
                    img = opened.convert('RGB')
            else:
                img = source.copy()
            draw = ImageDraw.Draw(img)

            score_text = f"Quality: {score:.4f}"
            draw.rectangle(box_coords, fill=self._score_to_color(score, score_range))

            # textbbox may start at a non-zero offset; subtract it so the
            # rendered glyphs, not the text origin, end up centered.
            left, top, right, bottom = draw.textbbox((0, 0), score_text, font=font)
            text_x = margin + (box_width - (right - left)) // 2 - left
            text_y = margin + (box_height - (bottom - top)) // 2 - top
            draw.text((text_x, text_y), score_text, fill='black', font=font)

            output_file = output_path / f"miqa_{self.model_name}_{result['image_name']}"
            img.save(output_file)

        self.logger.info(f"✓ Visualizations saved to: {output_dir}/")

    def save_results(self, results: List[Dict], output_path: str = 'predictions.json',
                     format: str = 'json') -> None:
        """
        Save prediction results to file.

        Args:
            results: List of prediction results
            output_path: Path to save results
            format: Output format - 'json' or 'csv'

        Raises:
            ValueError: If ``format`` is neither 'json' nor 'csv'.
        """
        if format not in ('json', 'csv'):
            # Reject up front rather than logging success without writing.
            raise ValueError(f"Unsupported output format '{format}'. Supported: ['json', 'csv']")

        # PIL images are not serializable; drop them from every record.
        clean_results = [
            {k: v for k, v in r.items() if k != 'original_image'}
            for r in results
        ]

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == 'json':
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'metadata': {
                        'task': self.task,
                        'model': self.model_name,
                        'metric_type': self.metric_type,
                        'timestamp': datetime.now().isoformat(),
                        'num_images': len(clean_results)
                    },
                    'predictions': clean_results
                }, f, indent=2)
        else:
            import csv
            # Successful and failed records carry different keys ('error' only
            # on failures), so the header is the ordered union of all keys.
            fieldnames = list(dict.fromkeys(k for row in clean_results for k in row))
            with open(output_path, 'w', newline='', encoding='utf-8') as f:
                if clean_results:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(clean_results)

        self.logger.info(f"💾 Results saved to: {output_path}")

    def print_summary(self, results: List[Dict]) -> None:
        """Log aggregate statistics, the best/worst images, and the failure count."""
        valid_results = [r for r in results if r.get('quality_score') is not None]
        failed_results = [r for r in results if r.get('quality_score') is None]

        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREDICTION SUMMARY")
        self.logger.info("=" * 80)

        if valid_results:
            scores = [r['quality_score'] for r in valid_results]
            self.logger.info(f"✓ Successfully processed: {len(valid_results)} images")
            self.logger.info(f" Average quality score: {np.mean(scores):.4f}")
            self.logger.info(f" Score range: [{np.min(scores):.4f}, {np.max(scores):.4f}]")
            self.logger.info(f" Standard deviation: {np.std(scores):.4f}")

            # Highest score first.
            sorted_results = sorted(valid_results, key=lambda x: x['quality_score'], reverse=True)

            self.logger.info("\n🏆 Top 3 quality images:")
            for i, r in enumerate(sorted_results[:3], 1):
                self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

            # Only shown when there are more images than the top list holds.
            if len(sorted_results) > 3:
                self.logger.info("\n⚠️ Bottom 3 quality images:")
                for i, r in enumerate(sorted_results[-3:], 1):
                    self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

        if failed_results:
            self.logger.info(f"\n❌ Failed to process: {len(failed_results)} images")

        self.logger.info("=" * 80 + "\n")
|
|
|
|
def main():
    """
    Command-line interface for MIQA inference.

    Parses the arguments, runs prediction on ``--input``, prints a summary and
    optionally writes results and visualizations under
    ``<save-dir>/image/<task>/<metric-type>``. Any exception is reported on
    stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description='MIQA: Machine-centric Image Quality Assessment',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Predict quality of a single image
  python img_inference.py --input image.jpg --task cls --model ra_miqa

  # Process all images in a directory
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task det --model ra_miqa

  # Save results and create visualizations
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task ins --save-results --visualize
"""
    )

    parser.add_argument('--input', type=str, required=True,
                        help='Path to input image or directory containing images')
    parser.add_argument('--task', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Task type: cls (classification), det (detection), ins (instance)')
    parser.add_argument('--model', type=str, default='ra_miqa',
                        choices=['ra_miqa'],
                        help='Model architecture (default: ra_miqa; Hub weights are RA-MIQA only)')
    parser.add_argument('--metric-type', type=str, default='composite',
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric type (default: composite)')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on (auto-detect if not specified)')
    parser.add_argument('--save-results', action='store_true',
                        help='Save prediction results to file')
    # No static default: the default file name follows --output-format so a
    # CSV export is not written to a '.json' file.
    parser.add_argument('--output-file', type=str, default=None,
                        help='Output file path for results (default: predictions.<output-format>)')
    parser.add_argument('--output-format', type=str, default='json',
                        choices=['json', 'csv'],
                        help='Output file format (default: json)')
    parser.add_argument('--visualize', action='store_true',
                        help='Create annotated visualizations of predictions')
    parser.add_argument('--save-dir', type=str, default='inference_results',
                        help='Directory for save (default: inference_results)')
    parser.add_argument('--no-progress', action='store_true',
                        help='Disable progress bar')

    args = parser.parse_args()

    try:
        miqa = MIQAInference(
            task=args.task,
            model_name=args.model,
            metric_type=args.metric_type,
            device=args.device
        )

        results = miqa.predict(args.input, show_progress=not args.no_progress)

        miqa.print_summary(results)

        # Outputs are grouped by modality / task / metric under the save dir.
        save_dir = Path(args.save_dir) / 'image' / args.task / args.metric_type

        # Prefix only the file name, so a value that contains directories
        # keeps them intact instead of getting the prefix glued to its front.
        requested = Path(args.output_file or f'predictions.{args.output_format}')
        output_file = requested.with_name(f'miqa_{args.model}_{requested.name}')

        if args.save_results:
            miqa.save_results(results, str(save_dir / output_file), args.output_format)

        if args.visualize:
            miqa.visualize_results(results, str(save_dir))

    except Exception as e:
        # Top-level boundary: report and signal failure through the exit code.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|