miqa
xiaoqi-wang committed on
Commit
c7e671b
·
verified ·
1 Parent(s): cbf7753

Upload img_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. img_inference.py +599 -0
img_inference.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import argparse
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional, Tuple
8
+ from collections import OrderedDict
9
+ import json
10
+ from datetime import datetime
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+
14
+ # Image processing imports
15
+ from PIL import Image, ImageDraw, ImageFont
16
+ import torchvision.transforms as transforms
17
+
18
+ # Import your existing model components
19
+ from models.MIQA_base import get_torch_model, get_timm_model
20
+ from models.RA_MIQA import RegionVisionTransformer
21
+ from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
22
+ from utils.hf_download_utils import ensure_checkpoint_from_hf
23
+
24
# Supported image file extensions (always written with the leading dot).
# '.JPEG' is listed explicitly because some datasets ship upper-case
# suffixes and directory globbing is case-sensitive on most filesystems.
SUPPORTED_EXTENSIONS = {'.jpg', '.jpeg', '.JPEG', '.png', '.bmp', '.tiff', '.tif'}
26
+
27
+
28
class MIQAInference:
    """
    Inference wrapper for MIQA models.

    This class handles model initialization, automatic weight downloading,
    image preprocessing, batch prediction, and result visualization.
    """

    def __init__(self, task: str, model_name: str = 'ra_miqa',
                 metric_type: str = 'composite', device: Optional[str] = None):
        """
        Initialize the MIQA inference system.

        Args:
            task: Task type - 'cls' (classification), 'det' (detection), or 'ins' (instance)
            model_name: Model architecture to use (default: RA_MIQA for best performance)
            metric_type: Training objective - 'composite', 'consistency', or 'accuracy'
            device: Device to run inference on ('cuda' or 'cpu'). Auto-detects if None.

        Raises:
            ValueError: If the task / model / metric combination is not registered.
            RuntimeError: If the model weights cannot be obtained.
        """
        self.task = task.lower()
        self.model_name = model_name
        self.metric_type = metric_type

        # Setup logging with clean formatting
        self.logger = self._setup_logger()

        # Determine computation device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.logger.info("🚀 Initializing MIQA Inference System")
        self.logger.info(f" Task: {self.task.upper()}")
        self.logger.info(f" Model: {self.model_name}")
        self.logger.info(f" Metric Type: {self.metric_type}")
        self.logger.info(f" Device: {self.device}")

        # Validate configuration before any download / model construction
        self._validate_config()

        # Initialize model
        self.model = self._load_model()

        # Setup image preprocessing pipeline
        self.transforms1, self.transforms2 = self._get_transforms()

        self.logger.info("✅ System ready for inference\n")

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'MIQA_Inference' logger, adding a console handler once."""
        logger = logging.getLogger('MIQA_Inference')
        logger.setLevel(logging.INFO)

        # Already configured (by a previous instance or by an ancestor logger):
        # do not stack a second handler, which would duplicate every message.
        if logger.hasHandlers():
            return logger

        logger.propagate = False

        # Console handler with clean formatting (message text only)
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)

        return logger

    def _validate_config(self) -> None:
        """
        Validate that the requested configuration is supported.

        Raises:
            ValueError: If metric_type, task or model_name is not present in
                the MODEL_FILENAMES registry.
        """
        supported_metrics = ['composite', 'consistency', 'accuracy']
        if self.metric_type not in supported_metrics:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. "
                f"Supported: {supported_metrics}"
            )

        tasks = MODEL_FILENAMES[self.metric_type]
        if self.task not in tasks:
            raise ValueError(
                f"Invalid task '{self.task}'. "
                f"Supported tasks: {list(tasks.keys())}"
            )

        models = tasks[self.task]
        if self.model_name not in models:
            available = list(models.keys())
            raise ValueError(
                f"Model '{self.model_name}' not available for task '{self.task}'. "
                f"Available models: {available}"
            )

    def _get_checkpoint_path(self) -> str:
        """
        Return the local path where the model checkpoint should be stored.

        The parent directory is created as a side effect.
        """
        base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
        base_dir.mkdir(parents=True, exist_ok=True)

        filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
        return str(base_dir / filename)

    def _download_weights(self, checkpoint_path: str) -> bool:
        """
        Download model weights if not present locally.

        Args:
            checkpoint_path: Expected local path of the checkpoint file.

        Returns:
            True if weights are available (already existed or successfully downloaded),
            False if the download raised an exception.
        """
        if os.path.exists(checkpoint_path):
            self.logger.info("✓ Found cached model weights")
            return True

        target = Path(checkpoint_path)
        self.logger.info(
            f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
            f"file={target.name}, rev={HF_REVISION}"
        )
        try:
            ensure_checkpoint_from_hf(
                repo_id=HF_REPO_ID,
                filename=target.name,
                local_dir=str(target.parent),
                revision=HF_REVISION,
            )
        except Exception as e:
            # Best-effort boundary: the caller decides how to react to False.
            self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
            return False

        self.logger.info("✓ Successfully downloaded model weights")
        return True

    def _create_model(self) -> torch.nn.Module:
        """
        Create the (untrained) model architecture for ``self.model_name``.

        'ra_miqa' builds a RegionVisionTransformer; any other name is first
        tried with get_torch_model and, if that raises, with get_timm_model.
        """
        if self.model_name == 'ra_miqa':
            self.logger.info("Building Region-Aware Vision Transformer...")
            model = RegionVisionTransformer(
                base_model_name='vit_small_patch16_224',
                pretrained=False,  # We'll load our trained weights
                mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
                checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
            )
        else:
            try:
                self.logger.info(f"Building {self.model_name} from PyTorch...")
                model = get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
            except Exception:
                self.logger.info(f"Building {self.model_name} from timm library...")
                model = get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)

        return model

    def _load_model(self) -> torch.nn.Module:
        """
        Load the model with its trained weights and put it in eval mode.

        Raises:
            RuntimeError: If the checkpoint is neither cached nor downloadable.
        """
        checkpoint_path = self._get_checkpoint_path()

        # Ensure weights are available
        if not self._download_weights(checkpoint_path):
            raise RuntimeError("Cannot proceed without model weights")

        # Create model architecture
        self.logger.info("🔧 Loading model...")
        model = self._create_model()

        # Load weights.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)

        # Strip only the *leading* 'module.' added by DataParallel training.
        # str.replace would also eat inner occurrences such as 'attn_module.'.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

        model.load_state_dict(new_state_dict, strict=True)
        model = model.to(self.device)
        model.eval()  # Set to evaluation mode

        self.logger.info("✓ Model loaded successfully")

        return model

    def _get_transforms(self) -> "Tuple[transforms.Compose, Optional[transforms.Compose]]":
        """
        Return preprocessing transforms based on model type.

        Returns:
            (primary, secondary). ``secondary`` is None for single-input
            backbones. For 'ra_miqa' the primary transform crops to 224 with
            0.5/0.5 normalization and the secondary crops to 288 with
            ImageNet normalization.
        """
        # The annotation above is a string so that it is not evaluated at
        # definition time ('X | None' on classes fails before Python 3.10).
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)
        SIMPLE_MEAN = (0.5, 0.5, 0.5)
        SIMPLE_STD = (0.5, 0.5, 0.5)

        def build(crop, mean, std):
            # Resize to 288, center-crop, convert to tensor, normalize.
            return transforms.Compose([
                transforms.Resize(288),
                transforms.CenterCrop(crop),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

        # 1) CNNs (ResNet / EfficientNet)
        if any(k in self.model_name for k in ['resnet', 'efficientnet']):
            return build(224, IMAGENET_MEAN, IMAGENET_STD), None

        # 2) ViT
        if 'vit' in self.model_name:
            return build(224, SIMPLE_MEAN, SIMPLE_STD), None

        # 3) ra_miqa (two inputs)
        if 'ra_miqa' in self.model_name:
            transform_1 = build(224, SIMPLE_MEAN, SIMPLE_STD)
            transform_2 = build((288, 288), IMAGENET_MEAN, IMAGENET_STD)
            return transform_1, transform_2

        # Fallback: routed through the logger like every other message.
        self.logger.warning(
            f"[Warning] Unknown model type '{self.model_name}', using ImageNet normalization."
        )
        return build(224, IMAGENET_MEAN, IMAGENET_STD), None

    def _prepare_image(self, image_path: str) -> "Tuple[torch.Tensor, Optional[torch.Tensor], Image.Image]":
        """
        Load and preprocess a single image.

        Returns:
            (img1, img2, original_img): the batched primary tensor, the
            batched secondary tensor (None when the model takes one input),
            and the RGB PIL image.
        """
        # Use a context manager so the underlying file handle is released;
        # convert() returns an independent in-memory copy.
        with Image.open(image_path) as opened:
            img = opened.convert('RGB')

        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0) if self.transforms2 is not None else None
        return img1, img2, img

    @torch.no_grad()
    def predict_single(self, image_path: str) -> Dict:
        """
        Run the model on one image.

        Returns:
            Dict with 'image_path', 'image_name', 'quality_score' (float) and
            'original_image' (the RGB PIL image).
        """
        img1, img2, original_img = self._prepare_image(image_path)

        img1 = img1.to(self.device)
        if img2 is None:
            output = self.model(img1)
        else:
            output = self.model(img1, img2.to(self.device))

        score = output.item() if torch.is_tensor(output) else float(output)

        return {
            'image_path': image_path,
            'image_name': Path(image_path).name,
            'quality_score': score,
            'original_image': original_img
        }

    @torch.no_grad()
    def predict_batch(self, image_paths: List[str],
                      show_progress: bool = True) -> List[Dict]:
        """
        Run inference on multiple images with progress tracking.

        A failure on one image is logged and recorded (quality_score=None
        plus an 'error' message) instead of aborting the whole batch.

        Args:
            image_paths: List of paths to image files
            show_progress: Whether to display progress bar

        Returns:
            List of prediction results for each image
        """
        results = []

        iterator = tqdm(image_paths, desc="Processing images",
                        disable=not show_progress, ncols=80)

        for img_path in iterator:
            try:
                results.append(self.predict_single(img_path))
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to process {img_path}: {str(e)}")
                results.append({
                    'image_path': img_path,
                    'image_name': Path(img_path).name,
                    'quality_score': None,
                    'error': str(e)
                })

        return results

    def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]:
        """
        Main prediction interface - handles both single images and directories.

        Args:
            input_path: Path to an image file or directory containing images
            show_progress: Whether to show progress bar for batch processing

        Returns:
            List of prediction results

        Raises:
            ValueError: Unsupported extension, empty directory, or missing path.
        """
        input_path = Path(input_path)
        # Compare suffixes case-insensitively so '.JPG' / '.Png' are accepted.
        allowed = {ext.lower() for ext in SUPPORTED_EXTENSIONS}

        # Handle single file
        if input_path.is_file():
            if input_path.suffix.lower() not in allowed:
                raise ValueError(
                    f"Unsupported file extension: {input_path.suffix}. "
                    f"Supported: {SUPPORTED_EXTENSIONS}"
                )
            return [self.predict_single(str(input_path))]

        # Handle directory
        if input_path.is_dir():
            # One pass over the directory: every file is seen exactly once,
            # so case-insensitive filesystems cannot produce duplicates.
            image_paths = sorted(
                str(p) for p in input_path.iterdir()
                if p.is_file() and p.suffix.lower() in allowed
            )

            if not image_paths:
                raise ValueError(f"No supported images found in {input_path}")

            self.logger.info(f"📁 Found {len(image_paths)} images in directory")
            return self.predict_batch(image_paths, show_progress)

        raise ValueError(f"Input path does not exist: {input_path}")

    def visualize_results(self, results: List[Dict], output_dir: str = 'inference_results',
                          score_range: Tuple[float, float] = (0, 100)) -> None:
        """
        Create annotated visualizations of predictions.

        This generates images with quality scores overlaid in a color-coded box.
        Low scores appear in red, high scores in green.

        Args:
            results: List of prediction results
            output_dir: Directory to save visualizations
            score_range: Expected range of quality scores for color normalization

        Raises:
            ValueError: If both ends of score_range are equal.
        """
        low, high = score_range[0], score_range[1]
        span = high - low
        if span == 0:
            raise ValueError(f"score_range must span a non-zero interval, got {score_range}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self.logger.info("\n🎨 Creating visualizations...")

        # Load the font once (try a nice font, fall back to the default).
        # truetype raises OSError when the file is missing and ImportError
        # when Pillow was built without FreeType support.
        try:
            font = ImageFont.truetype("arial.ttf", 30)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        # Box geometry (top-left corner)
        box_width = 250
        box_height = 50
        margin = 10
        box_coords = [margin, margin, margin + box_width, margin + box_height]

        for result in tqdm(results, desc="Generating visualizations", ncols=80):
            score = result.get('quality_score')
            source_img = result.get('original_image')
            if score is None or source_img is None:
                continue  # Skip failed predictions / entries without an image

            img = source_img.copy()
            draw = ImageDraw.Draw(img)
            score_text = f"Quality: {score:.4f}"

            # Normalize score to [0, 1] for color interpolation
            norm_score = max(0, min(1, (score - low) / span))

            # Color interpolation: red (low quality) -> yellow -> green (high quality)
            if norm_score < 0.5:
                color = (255, int(255 * (norm_score * 2)), 0)   # red -> yellow
            else:
                color = (int(255 * (2 - norm_score * 2)), 255, 0)  # yellow -> green

            draw.rectangle(box_coords, fill=color)

            # Center the text inside the box
            bbox = draw.textbbox((0, 0), score_text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
            text_x = margin + (box_width - text_width) // 2
            text_y = margin + (box_height - text_height) // 2

            draw.text((text_x, text_y), score_text, fill='black', font=font)

            # Save annotated image
            output_file = output_path / f"miqa_{self.model_name}_{result['image_name']}"
            img.save(output_file)

        self.logger.info(f"✓ Visualizations saved to: {output_dir}/")

    def save_results(self, results: List[Dict], output_path: str = 'predictions.json',
                     format: str = 'json') -> None:
        """
        Save prediction results to file.

        Args:
            results: List of prediction results
            output_path: Path to save results
            format: Output format - 'json' or 'csv'

        Raises:
            ValueError: If format is neither 'json' nor 'csv'.
        """
        if format not in ('json', 'csv'):
            # Previously this fell through silently and still logged "saved".
            raise ValueError(f"Unsupported output format '{format}'. Supported: ['json', 'csv']")

        # Remove PIL image objects before saving (not serializable)
        clean_results = [
            {k: v for k, v in r.items() if k != 'original_image'}
            for r in results
        ]

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format == 'json':
            payload = {
                'metadata': {
                    'task': self.task,
                    'model': self.model_name,
                    'metric_type': self.metric_type,
                    'timestamp': datetime.now().isoformat(),
                    'num_images': len(clean_results)
                },
                'predictions': clean_results
            }
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, indent=2)
        else:
            import csv
            # Failed rows carry an extra 'error' key, so the header must be
            # the ordered union of all keys, not just those of the first row.
            fieldnames = list(dict.fromkeys(k for row in clean_results for k in row))
            with open(output_path, 'w', newline='', encoding='utf-8') as f:
                if clean_results:
                    writer = csv.DictWriter(f, fieldnames=fieldnames, restval='')
                    writer.writeheader()
                    writer.writerows(clean_results)

        self.logger.info(f"💾 Results saved to: {output_path}")

    def print_summary(self, results: List[Dict]) -> None:
        """Log a formatted summary (counts, score statistics, best / worst images)."""
        valid_results = [r for r in results if r.get('quality_score') is not None]
        failed_count = len(results) - len(valid_results)

        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREDICTION SUMMARY")
        self.logger.info("=" * 80)

        if valid_results:
            scores = [r['quality_score'] for r in valid_results]
            self.logger.info(f"✓ Successfully processed: {len(valid_results)} images")
            self.logger.info(f" Average quality score: {np.mean(scores):.4f}")
            self.logger.info(f" Score range: [{np.min(scores):.4f}, {np.max(scores):.4f}]")
            self.logger.info(f" Standard deviation: {np.std(scores):.4f}")

            # Show top and bottom quality images
            sorted_results = sorted(valid_results, key=lambda x: x['quality_score'], reverse=True)

            self.logger.info("\n🏆 Top 3 quality images:")
            for i, r in enumerate(sorted_results[:3], 1):
                self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

            # Only meaningful when there is more than one "top 3" worth of images
            if len(sorted_results) > 3:
                self.logger.info("\n⚠️ Bottom 3 quality images:")
                for i, r in enumerate(sorted_results[-3:], 1):
                    self.logger.info(f" {i}. {r['image_name']}: {r['quality_score']:.4f}")

        if failed_count:
            self.logger.info(f"\n❌ Failed to process: {failed_count} images")

        self.logger.info("=" * 80 + "\n")
517
+
518
+
519
def main():
    """
    Command-line interface for MIQA inference.

    Parses the arguments, runs prediction on the given image or directory,
    prints a summary and optionally writes the results file and annotated
    images under ``<save-dir>/image/<task>/<metric-type>/``. Any exception is
    reported on stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description='MIQA: Machine-centric Image Quality Assessment',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Predict quality of a single image
  python img_inference.py --input image.jpg --task cls --model ra_miqa

  # Process all images in a directory
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task det --model ra_miqa

  # Save results and create visualizations
  python img_inference.py --input ./assets/demo_images/imagenet_demo --task ins --save-results --visualize
"""
    )

    parser.add_argument('--input', type=str, required=True,
                        help='Path to input image or directory containing images')
    parser.add_argument('--task', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Task type: cls (classification), det (detection), ins (instance)')
    parser.add_argument('--model', type=str, default='ra_miqa',
                        choices=['ra_miqa'],
                        help='Model architecture (default: ra_miqa; Hub weights are RA-MIQA only)')
    parser.add_argument('--metric-type', type=str, default='composite',
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric type (default: composite)')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on (auto-detect if not specified)')
    parser.add_argument('--save-results', action='store_true',
                        help='Save prediction results to file')
    parser.add_argument('--output-file', type=str, default='predictions.json',
                        help='Output file path for results (default: predictions.json)')
    parser.add_argument('--output-format', type=str, default='json',
                        choices=['json', 'csv'],
                        help='Output file format (default: json)')
    parser.add_argument('--visualize', action='store_true',
                        help='Create annotated visualizations of predictions')
    parser.add_argument('--save-dir', type=str, default='inference_results',
                        help='Directory for save (default: inference_results)')
    parser.add_argument('--no-progress', action='store_true',
                        help='Disable progress bar')

    args = parser.parse_args()

    try:
        # Initialize inference system
        miqa = MIQAInference(
            task=args.task,
            model_name=args.model,
            metric_type=args.metric_type,
            device=args.device
        )

        # Run predictions
        results = miqa.predict(args.input, show_progress=not args.no_progress)

        # Print summary
        miqa.print_summary(results)

        save_dir = Path(args.save_dir) / 'image' / args.task / args.metric_type

        # Prefix only the file *name*: concatenating onto the whole string
        # would turn 'out/pred.json' into 'miqa_<model>_out/pred.json'.
        requested = Path(args.output_file)
        output_file = save_dir / requested.with_name(f'miqa_{args.model}_{requested.name}')

        # Save results if requested
        if args.save_results:
            miqa.save_results(results, output_file, args.output_format)

        # Create visualizations if requested
        if args.visualize:
            miqa.visualize_results(results, save_dir)

    except Exception as e:
        # Top-level boundary: report and signal failure to the shell.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()