miqa
xiaoqi-wang committed on
Commit
a7301d6
·
verified ·
1 Parent(s): 26ee8b6

Upload video_analytics_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. video_analytics_inference.py +883 -0
video_analytics_inference.py ADDED
@@ -0,0 +1,883 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import argparse
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional, Tuple
8
+ from collections import OrderedDict, defaultdict
9
+ import json
10
+ from datetime import datetime
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import cv2
14
+
15
+ # Image processing imports
16
+ from PIL import Image, ImageDraw, ImageFont
17
+ import torchvision.transforms as transforms
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib
20
+
21
+ matplotlib.use('Agg') # Use non-interactive backend
22
+
23
+ # Import your existing model components
24
+ from models.MIQA_base import get_torch_model, get_timm_model
25
+ from models.RA_MIQA import RegionVisionTransformer
26
+ from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
27
+ from utils.hf_download_utils import ensure_checkpoint_from_hf
28
+
29
# Supported file extensions (each entry includes the leading dot).
# The image set lists '.JPEG' alongside '.jpeg'; every other entry is lowercase only.
SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.JPEG', '.png', '.bmp', '.tiff', '.tif'}
# Video container extensions recognised by this script.
SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
32
+
33
+
34
class VideoFrameExtractor:
    """
    Extracts and samples frames from video files.

    Two sampling strategies are implemented:

    * 'uniform' - pick ``target_frames`` evenly spaced frame indices, or
      every frame when the video has no more than ``target_frames`` frames.
    * 'fps'     - pick one frame every ``int(fps / fps_sample)`` frames
      (at least every frame).

    Any other strategy name (including 'keyframe', which is not
    implemented) is rejected with ``ValueError`` when frames are sampled.
    """

    def __init__(self, sampling_strategy: str = 'uniform',
                 target_frames: int = 30,
                 fps_sample: Optional[float] = None):
        """
        Initialize the frame extractor.

        Args:
            sampling_strategy: How to sample frames - 'uniform' or 'fps'
            target_frames: Number of frames to extract (uniform sampling)
            fps_sample: Sample rate in frames per second (fps sampling)
        """
        self.sampling_strategy = sampling_strategy
        self.target_frames = target_frames
        self.fps_sample = fps_sample

    def extract_frames(self, video_path: str) -> Tuple[List[np.ndarray], List[float], Dict]:
        """
        Extract frames from a video according to the sampling strategy.

        Args:
            video_path: Path of the video file to open with OpenCV.

        Returns:
            Tuple of (frames_list, timestamps_list, video_metadata). Frames are
            converted from BGR to RGB. Timestamps are ``index / fps`` in
            seconds, or the frame index itself when the reported fps is not
            positive. Indices that fail to decode are skipped, so the lists
            may be shorter than the number of sampled indices.

        Raises:
            ValueError: If the video cannot be opened, or the sampling
                configuration is invalid.
        """
        cap = cv2.VideoCapture(video_path)

        # try/finally guarantees the capture is released even when sampling
        # validation or decoding raises.
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video file: {video_path}")

            # Get video properties
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames / fps if fps > 0 else 0

            metadata = {
                'total_frames': total_frames,
                'fps': fps,
                'duration': duration,
                'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            }

            frames = []
            timestamps = []

            for idx in self._get_sample_indices(total_frames, fps):
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if not ret:
                    continue

                # Convert BGR to RGB (OpenCV uses BGR)
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # Always a float so the value matches the annotated return type
                timestamps.append(idx / fps if fps > 0 else float(idx))
        finally:
            cap.release()

        return frames, timestamps, metadata

    def _get_sample_indices(self, total_frames: int, fps: float) -> List[int]:
        """
        Determine which frame indices to sample based on the strategy.

        Args:
            total_frames: Frame count reported for the video.
            fps: Frame rate reported for the video.

        Returns:
            Ascending list of frame indices to read.

        Raises:
            ValueError: For an unknown strategy, a non-positive
                ``target_frames`` (uniform), or a missing / non-positive
                ``fps_sample`` (fps).
        """
        if self.sampling_strategy == 'uniform':
            # Guard the division below against target_frames == 0
            if self.target_frames <= 0:
                raise ValueError("target_frames must be a positive integer for uniform sampling")

            if total_frames <= self.target_frames:
                return list(range(total_frames))

            # Fractional step keeps the samples spread over the whole video
            step = total_frames / self.target_frames
            return [int(i * step) for i in range(self.target_frames)]

        if self.sampling_strategy == 'fps':
            if self.fps_sample is None:
                raise ValueError("fps_sample must be specified for fps sampling strategy")
            # Guard the division below against fps_sample == 0
            if self.fps_sample <= 0:
                raise ValueError("fps_sample must be positive for fps sampling strategy")

            # max(1, ...) also covers videos whose reported fps is 0
            frame_interval = max(1, int(fps / self.fps_sample))
            return list(range(0, total_frames, frame_interval))

        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")
131
+
132
+
133
def aggregate_scores_by_second(frame_results: List[Dict]) -> List[Dict]:
    """
    Aggregate frame-level quality scores into per-second statistics.

    Frames are grouped by ``int(timestamp)`` (truncation, which equals the
    floor for non-negative timestamps) and summarised per group.

    Args:
        frame_results: List of dictionaries with 'timestamp' and 'quality_score'

    Returns:
        List of dictionaries sorted by second, each with the keys 'second',
        'timestamp', 'quality_score' (mean), 'min_score', 'max_score',
        'num_frames' and 'std_score' (0.0 for single-frame groups). All
        values are builtin ``int``/``float`` so the result can be passed to
        ``json.dump`` whatever the dtype of the input scores.
    """
    # Group scores by the second their timestamp falls in
    scores_by_second = defaultdict(list)
    for entry in frame_results:
        scores_by_second[int(entry['timestamp'])].append(entry['quality_score'])

    per_second_results = []
    for second in sorted(scores_by_second):
        scores = np.asarray(scores_by_second[second], dtype=np.float64)
        per_second_results.append({
            'second': second,
            'timestamp': float(second),  # Use second as timestamp for plotting
            'quality_score': float(scores.mean()),
            'min_score': float(scores.min()),
            'max_score': float(scores.max()),
            'num_frames': int(scores.size),
            'std_score': float(scores.std()) if scores.size > 1 else 0.0,
        })

    return per_second_results
169
+
170
+
171
class MIQAInference:
    """
    Inference wrapper for MIQA models supporting both images and videos.

    Construction validates the requested configuration, makes sure the
    checkpoint is available (downloading it when missing), builds the model
    and prepares the two preprocessing pipelines the model is called with.
    """

    def __init__(self, task: str, model_name: str = 'ra_miqa',
                 metric_type: str = 'composite', device: Optional[str] = None,
                 video_sampling: str = 'uniform', video_target_frames: int = 30,
                 video_fps_sample: Optional[float] = None):
        """
        Initialize the MIQA inference system.

        Args:
            task: Task type - 'cls', 'det', or 'ins'
            model_name: Model architecture to use
            metric_type: Training objective - 'composite', 'consistency', or 'accuracy'
            device: Device to run inference on (CUDA if available when None)
            video_sampling: Frame sampling strategy for videos
            video_target_frames: Target number of frames to extract from videos
            video_fps_sample: Sampling rate in frames per second; needed when
                ``video_sampling`` is 'fps'

        Raises:
            ValueError: If the configuration is not in the model registry.
            RuntimeError: If the model weights cannot be obtained.
        """
        self.task = task.lower()
        self.model_name = model_name
        self.metric_type = metric_type
        self.video_target_frames = video_target_frames

        # Setup logging
        self.logger = self._setup_logger()

        # Determine device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.logger.info("🚀 Initializing MIQA Inference System")
        self.logger.info(f" Task: {self.task.upper()}")
        self.logger.info(f" Model: {self.model_name}")
        self.logger.info(f" Device: {self.device}")

        # Validate configuration
        self._validate_config()

        # Initialize model
        self.model = self._load_model()

        # Setup image preprocessing
        self.transforms1, self.transforms2 = self._get_transforms()

        # Initialize video frame extractor; fps_sample is forwarded so that
        # the 'fps' strategy can actually be used.
        self.frame_extractor = VideoFrameExtractor(
            sampling_strategy=video_sampling,
            target_frames=video_target_frames,
            fps_sample=video_fps_sample
        )

        self.logger.info("✅ System ready for inference\n")

    def _setup_logger(self) -> logging.Logger:
        """Return the 'MIQA_Inference' logger, attaching a console handler once."""
        logger = logging.getLogger('MIQA_Inference')
        logger.setLevel(logging.INFO)

        # Inspect this logger's own handlers: hasHandlers() also reports
        # handlers attached to ancestors (e.g. the root logger), which would
        # skip the console handler below.
        if logger.handlers:
            return logger

        logger.propagate = False

        # Console handler with clean formatting
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)

        return logger

    def _validate_config(self) -> None:
        """
        Validate that the requested configuration is supported.

        Raises:
            ValueError: If metric type, task or model is not found in
                ``MODEL_FILENAMES``.
        """
        supported_metrics = ['composite', 'consistency', 'accuracy']
        if self.metric_type not in supported_metrics:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. "
                f"Supported: {supported_metrics}"
            )

        tasks = MODEL_FILENAMES[self.metric_type]
        if self.task not in tasks:
            raise ValueError(
                f"Invalid task '{self.task}'. "
                f"Supported tasks: {list(tasks.keys())}"
            )

        models = tasks[self.task]
        if self.model_name not in models:
            available = list(models.keys())
            raise ValueError(
                f"Model '{self.model_name}' not available for task '{self.task}'. "
                f"Available models: {available}"
            )

    def _get_checkpoint_path(self) -> str:
        """Return the local checkpoint path, creating its directory if needed."""
        base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
        base_dir.mkdir(parents=True, exist_ok=True)

        filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
        return str(base_dir / filename)

    def _download_weights(self, checkpoint_path: str) -> bool:
        """
        Download model weights if not present locally.

        Args:
            checkpoint_path: Local path where the checkpoint is expected.

        Returns:
            True if weights are available (already existed or successfully
            downloaded), False if the download raised.
        """
        if os.path.exists(checkpoint_path):
            self.logger.info("✓ Found cached model weights")
            return True

        target = Path(checkpoint_path)
        self.logger.info(
            f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
            f"file={target.name}, rev={HF_REVISION}"
        )
        try:
            ensure_checkpoint_from_hf(
                repo_id=HF_REPO_ID,
                filename=target.name,
                local_dir=str(target.parent),
                revision=HF_REVISION,
            )
        except Exception as e:
            # Best effort: report the failure and let the caller decide
            self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
            return False

        self.logger.info("✓ Successfully downloaded model weights")
        return True

    def _create_model(self) -> torch.nn.Module:
        """Create the model architecture (without the task-specific weights)."""
        if self.model_name == 'ra_miqa':
            self.logger.info("Building Region-Aware Vision Transformer...")
            return RegionVisionTransformer(
                base_model_name='vit_small_patch16_224',
                pretrained=False,  # We'll load our trained weights
                mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
                checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
            )

        # Any failure in the torch factory falls back to the timm factory
        try:
            self.logger.info(f"Building {self.model_name} from PyTorch...")
            return get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
        except Exception:
            self.logger.info(f"Building {self.model_name} from timm library...")
            return get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)

    @staticmethod
    def _strip_dataparallel_prefix(state_dict) -> OrderedDict:
        """
        Return a copy of ``state_dict`` without a leading 'module.' on keys.

        Only the prefix is removed; a 'module.' appearing later in a key is
        part of the parameter name and is kept.
        """
        prefix = 'module.'
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
        return cleaned

    def _load_model(self) -> torch.nn.Module:
        """
        Build the model, load its checkpoint and switch it to eval mode.

        Raises:
            RuntimeError: If the checkpoint is neither cached nor downloadable.
        """
        checkpoint_path = self._get_checkpoint_path()

        if not self._download_weights(checkpoint_path):
            raise RuntimeError("Cannot proceed without model weights")

        self.logger.info("🔧 Loading model...")
        model = self._create_model()

        # NOTE(review): torch.load unpickles the downloaded file; consider
        # weights_only=True if the checkpoints contain tensors only.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)

        model.load_state_dict(self._strip_dataparallel_prefix(state_dict), strict=True)
        model = model.to(self.device)
        model.eval()

        self.logger.info("✓ Model loaded successfully")
        return model

    def _get_transforms(self) -> "Tuple[transforms.Compose, transforms.Compose]":
        """
        Get the two image preprocessing pipelines.

        Returns:
            Tuple of (pipeline 1, pipeline 2). Pipeline 1 resizes to 288,
            center-crops to 224 and normalizes with mean/std 0.5. Pipeline 2
            resizes to 288, center-crops to 288x288 and normalizes with the
            ImageNet mean/std.
        """
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)
        SIMPLE_MEAN = (0.5, 0.5, 0.5)
        SIMPLE_STD = (0.5, 0.5, 0.5)

        cropped_pipeline = transforms.Compose([
            transforms.Resize(288),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD),
        ])
        resized_pipeline = transforms.Compose([
            transforms.Resize(288),
            transforms.CenterCrop((288, 288)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        return cropped_pipeline, resized_pipeline

    def _prepare_frame(self, frame: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a video frame for model input.

        Args:
            frame: Numpy array in RGB format

        Returns:
            Tuple of (cropped_tensor, resized_tensor), each with a leading
            batch dimension of 1.
        """
        img = Image.fromarray(frame)
        return self.transforms1(img).unsqueeze(0), self.transforms2(img).unsqueeze(0)

    def _prepare_image(self, image_path: str) -> "Tuple[torch.Tensor, torch.Tensor, Image.Image]":
        """
        Load and preprocess an image file.

        Returns:
            Tuple of (cropped_tensor, resized_tensor, RGB PIL image).
        """
        # convert() returns a new image, so the file can be closed right away
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0)
        return img1, img2, img

    @torch.no_grad()
    def predict_single_image(self, image_path: str) -> Dict:
        """
        Run inference on a single image.

        Returns:
            Dictionary with 'image_path', 'image_name', 'quality_score',
            'original_image' (PIL image) and 'type' set to 'image'.
        """
        img_cropped, img_resized, original_img = self._prepare_image(image_path)

        output = self.model(img_cropped.to(self.device), img_resized.to(self.device))

        return {
            'image_path': image_path,
            'image_name': Path(image_path).name,
            'quality_score': output.item(),
            'original_image': original_img,
            'type': 'image'
        }

    @torch.no_grad()
    def predict_video(self, video_path: str, show_progress: bool = True) -> Dict:
        """
        Run inference on a video file.

        This extracts frames, predicts quality for each, aggregates to
        per-second averages, and returns time-series data suitable for
        visualization. The summary statistics are computed over the
        per-second averages.

        Args:
            video_path: Path of the video file.
            show_progress: Whether to display a progress bar.

        Raises:
            ValueError: If the video cannot be opened or no frame could be
                decoded from it.
        """
        self.logger.info(f"🎬 Processing video: {Path(video_path).name}")

        # Extract frames
        frames, timestamps, metadata = self.frame_extractor.extract_frames(video_path)

        self.logger.info(f" Extracted {len(frames)} frames from {metadata['duration']:.1f}s video")

        # Without frames the statistics below would fail inside NumPy with an
        # unhelpful "zero-size array" error.
        if not frames:
            raise ValueError(f"No frames could be decoded from video: {video_path}")

        # Process each frame
        frame_results = []
        iterator = tqdm(frames, desc="Analyzing frames", disable=not show_progress, ncols=80)

        for frame, timestamp in zip(iterator, timestamps):
            img_cropped, img_resized = self._prepare_frame(frame)
            output = self.model(img_cropped.to(self.device), img_resized.to(self.device))

            frame_results.append({
                'timestamp': timestamp,
                'quality_score': output.item(),
                'frame': frame  # Store for visualization if needed
            })

        # Aggregate frame results by second
        per_second_results = aggregate_scores_by_second(frame_results)
        second_scores = [r['quality_score'] for r in per_second_results]

        return {
            'video_path': video_path,
            'video_name': Path(video_path).name,
            'type': 'video',
            'metadata': metadata,
            'frame_results': frame_results,
            'per_second_results': per_second_results,
            'num_frames_analyzed': len(frame_results),
            'num_seconds': len(per_second_results),
            'average_quality': np.mean(second_scores),
            'min_quality': np.min(second_scores),
            'max_quality': np.max(second_scores),
            'std_quality': np.std(second_scores)
        }

    @staticmethod
    def _collect_by_suffix(directory: Path, extensions) -> List[str]:
        """
        Return sorted paths of the files directly inside ``directory`` whose
        suffix matches one of ``extensions``, ignoring case.
        """
        wanted = {ext.lower() for ext in extensions}
        return sorted(
            str(entry) for entry in directory.iterdir()
            if entry.is_file() and entry.suffix.lower() in wanted
        )

    def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]:
        """
        Main prediction interface - handles images, videos, and directories.

        Args:
            input_path: Path of an image, a video, or a directory. Directories
                are scanned non-recursively.
            show_progress: Whether to display per-video progress bars.

        Returns:
            One result dictionary per successfully processed file. In
            directory mode, files that fail are logged and skipped.

        Raises:
            ValueError: If the path does not exist, the file extension is not
                supported, or the directory holds no supported files.
        """
        input_path = Path(input_path)

        # Handle single file
        if input_path.is_file():
            ext = input_path.suffix.lower()

            if ext in SUPPORTED_IMAGE_EXTENSIONS:
                return [self.predict_single_image(str(input_path))]
            if ext in SUPPORTED_VIDEO_EXTENSIONS:
                return [self.predict_video(str(input_path), show_progress)]
            raise ValueError(f"Unsupported file extension: {ext}")

        if not input_path.is_dir():
            raise ValueError(f"Input path does not exist: {input_path}")

        # One case-insensitive pass: finds '.JPG' / '.MP4' on case-sensitive
        # filesystems and cannot list a file twice on case-insensitive ones.
        image_paths = self._collect_by_suffix(input_path, SUPPORTED_IMAGE_EXTENSIONS)
        video_paths = self._collect_by_suffix(input_path, SUPPORTED_VIDEO_EXTENSIONS)

        if not image_paths and not video_paths:
            raise ValueError(f"No supported files found in {input_path}")

        self.logger.info(f"📁 Found {len(image_paths)} images and {len(video_paths)} videos")

        results = []

        # Process images
        if image_paths:
            for img_path in tqdm(image_paths, desc="Processing images", ncols=80):
                try:
                    results.append(self.predict_single_image(img_path))
                except Exception as e:
                    # Keep going, but report why the file failed
                    self.logger.warning(f"⚠️ Failed: {img_path} ({e})")

        # Process videos
        for vid_path in video_paths:
            try:
                results.append(self.predict_video(vid_path, show_progress))
            except Exception as e:
                self.logger.warning(f"⚠️ Failed: {vid_path} ({e})")

        return results

    def visualize_video_results(self, video_result: Dict, output_dir: str = 'inference_results',
                                granularity: str = 'second') -> None:
        """
        Create time-series visualization for video quality predictions.

        This generates a line plot showing how quality varies across the
        video timeline and saves it as a PNG in ``output_dir``.

        Args:
            video_result: Dictionary containing video analysis results
            output_dir: Directory to save visualizations
            granularity: 'frame' for frame-by-frame, 'both' for a side-by-side
                figure; any other value plots the per-second averages
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        avg_score = video_result['average_quality']

        if granularity == 'both':
            # Create side-by-side comparison
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

            # Frame-level plot (left)
            frame_results = video_result['frame_results']
            frame_timestamps = [r['timestamp'] for r in frame_results]
            frame_scores = [r['quality_score'] for r in frame_results]

            ax1.plot(frame_timestamps, frame_scores, linewidth=1.5, color='#2E86AB',
                     marker='o', markersize=3, markerfacecolor='white', markeredgewidth=1,
                     alpha=0.7, label='Frame-level')
            ax1.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold')
            ax1.set_ylabel('Quality Score', fontsize=11, fontweight='bold')
            ax1.set_title('Frame-Level Quality', fontsize=12, fontweight='bold')
            ax1.grid(True, alpha=0.3, linestyle='--')
            ax1.set_ylim(0, 1)

            # Second-level plot (right)
            second_results = video_result['per_second_results']
            second_timestamps = [r['timestamp'] for r in second_results]
            second_scores = [r['quality_score'] for r in second_results]

            ax2.plot(second_timestamps, second_scores, linewidth=2.5, color='#A23B72',
                     marker='s', markersize=6, markerfacecolor='white', markeredgewidth=1.5,
                     label='Per-second average')

            # Shaded band of +/- one standard deviation within each second
            if second_results and 'std_score' in second_results[0]:
                stds = np.array([r['std_score'] for r in second_results])
                centers = np.array(second_scores)
                ax2.fill_between(second_timestamps, centers - stds, centers + stds,
                                 alpha=0.2, color='#A23B72')

            ax2.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold')
            ax2.set_ylabel('Quality Score', fontsize=11, fontweight='bold')
            ax2.set_title('Per-Second Averaged Quality', fontsize=12, fontweight='bold')
            ax2.grid(True, alpha=0.3, linestyle='--')
            ax2.set_ylim(0, 1)

            # Add overall average line and legend to both panels
            for axis in (ax1, ax2):
                axis.axhline(y=avg_score, color='#F18F01', linestyle='--',
                             linewidth=1.5, alpha=0.7, label=f'Overall avg: {avg_score:.2f}')
                axis.legend(loc='best', framealpha=0.9)

            plt.suptitle(f"Video Quality Analysis: {video_result['video_name']}, {self.task}-oriented MIQA",
                         fontsize=14, fontweight='bold', y=1.02)

            stats_text = (
                f"Duration: {video_result['metadata']['duration']:.1f}s\n"
                f"Frames: {video_result['num_frames_analyzed']} | "
                f"Seconds: {video_result['num_seconds']}\n"
                f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n"
                f"Std Dev: {video_result['std_quality']:.2f}"
            )
            # Statistics box goes on the right subplot
            ax2.text(0.02, 0.98, stats_text, transform=ax2.transAxes,
                     fontsize=9, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            suffix = 'comparison'

        else:
            # Single plot based on selected granularity
            plt.figure(figsize=(14, 6))

            if granularity == 'frame':
                series = video_result['frame_results']
                second_results = []  # no variability band for frame-level plots
                plot_color = '#2E86AB'
                plot_label = 'Frame-level quality'
                title_suffix = '(Frame-Level)'
                suffix = 'frame'
                marker_size = 4
            else:  # second
                series = second_results = video_result['per_second_results']
                plot_color = '#A23B72'
                plot_label = 'Per-second average'
                title_suffix = '(Per-Second Average)'
                suffix = 'second'
                marker_size = 6

            timestamps = [r['timestamp'] for r in series]
            scores = [r['quality_score'] for r in series]

            # Main quality plot
            plt.plot(timestamps, scores, linewidth=2, color=plot_color, marker='o',
                     markersize=marker_size, markerfacecolor='white', markeredgewidth=1.5,
                     label=plot_label)

            # Shaded band of +/- one standard deviation within each second
            if second_results and 'std_score' in second_results[0]:
                stds = np.array([r['std_score'] for r in second_results])
                centers = np.array(scores)
                plt.fill_between(timestamps, centers - stds, centers + stds,
                                 alpha=0.2, color=plot_color)

            # Add average line
            plt.axhline(y=avg_score, color='#F18F01', linestyle='--',
                        linewidth=1.5, label=f'Average: {avg_score:.2f}')

            # Styling
            plt.xlabel('Time (seconds)', fontsize=12, fontweight='bold')
            plt.ylabel('Quality Score', fontsize=12, fontweight='bold')
            plt.title(f"Video Quality Analysis: {video_result['video_name']} {title_suffix}",
                      fontsize=14, fontweight='bold', pad=20)
            plt.grid(True, alpha=0.3, linestyle='--')
            plt.legend(loc='best', framealpha=0.9)

            stats_text = (
                f"Duration: {video_result['metadata']['duration']:.1f}s\n"
                f"Frames Analyzed: {video_result['num_frames_analyzed']}\n"
                f"Unique Seconds: {video_result['num_seconds']}\n"
                f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n"
                f"Std Dev: {video_result['std_quality']:.2f}"
            )
            plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
                     fontsize=12, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.6))

        plt.tight_layout()

        # Save figure
        output_file = output_path / f"{Path(video_result['video_name']).stem}_{self.metric_type}_quality_{suffix}.png"
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()

        self.logger.info(f" Saved visualization: {output_file.name}")

    @staticmethod
    def _score_to_color(score: float) -> Tuple[int, int, int]:
        """
        Map a score to an RGB colour running red -> yellow -> green.

        The score is clamped to [0, 1] first so that every channel stays
        within 0..255.
        """
        norm_score = min(1.0, max(0.0, score))
        if norm_score < 0.5:
            return 255, int(255 * norm_score * 2), 0
        return int(255 * (2 - norm_score * 2)), 255, 0

    def visualize_results(self, results: List[Dict], output_dir: str = 'inference_results',
                          video_granularity: str = 'second') -> None:
        """
        Create visualizations for all results (images and videos).

        Videos get a quality-over-time plot; images are saved as
        ``annotated_<name>`` with a colour-coded score box drawn on them.

        Args:
            results: List of prediction results
            output_dir: Directory to save visualizations
            video_granularity: For videos - 'frame', 'second', or 'both'
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self.logger.info(f"\n🎨 Creating visualizations (granularity: {video_granularity})...")

        for result in results:
            if result['type'] == 'video':
                self.visualize_video_results(result, output_dir, video_granularity)
            elif result['type'] == 'image' and result.get('quality_score') is not None:
                # Draw on a copy so the stored original stays untouched
                img = result['original_image'].copy()
                draw = ImageDraw.Draw(img)

                score = result['quality_score']
                score_text = f"Quality: {score:.3f}"

                draw.rectangle([10, 10, 260, 60], fill=self._score_to_color(score))

                # truetype() raises OSError when the font file is unavailable
                try:
                    font = ImageFont.truetype("arial.ttf", 24)
                except OSError:
                    font = ImageFont.load_default()

                draw.text((15, 20), score_text, fill='black', font=font)

                img.save(output_path / f"annotated_{result['image_name']}")

        self.logger.info(f"✓ Visualizations saved to: {output_dir}/")

    def save_results(self, results: List[Dict], output_path: str = 'predictions.json') -> None:
        """
        Save prediction results to a JSON file.

        The 'original_image' entries and the per-frame 'frame' arrays are
        dropped before writing.

        Args:
            results: List of prediction results
            output_path: Destination file; parent directories are created
        """
        clean_results = []
        for r in results:
            clean_r = {k: v for k, v in r.items()
                       if k not in ['original_image', 'frame']}

            # For video results, remove frame data but keep scores
            if clean_r.get('type') == 'video' and 'frame_results' in clean_r:
                clean_r['frame_results'] = [
                    {k: v for k, v in fr.items() if k != 'frame'}
                    for fr in clean_r['frame_results']
                ]

            clean_results.append(clean_r)

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        payload = {
            'metadata': {
                'task': self.task,
                'model': self.model_name,
                'metric_type': self.metric_type,
                'timestamp': datetime.now().isoformat(),
                'total_files': len(clean_results)
            },
            'predictions': clean_results
        }
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)

        self.logger.info(f"💾 Results saved to: {output_path}")

    def print_summary(self, results: List[Dict]) -> None:
        """Log a formatted summary of image and video prediction results."""
        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREDICTION SUMMARY")
        self.logger.info("=" * 80)

        image_results = [r for r in results if r.get('type') == 'image']
        video_results = [r for r in results if r.get('type') == 'video']

        # Only images that actually carry a score contribute to the summary
        scores = [r['quality_score'] for r in image_results
                  if r.get('quality_score') is not None]
        if scores:
            self.logger.info(f"\n📸 Image Analysis ({len(scores)} images)")
            self.logger.info(f" Average quality: {np.mean(scores):.2f}")
            self.logger.info(f" Score range: [{np.min(scores):.2f}, {np.max(scores):.2f}]")

        if video_results:
            self.logger.info(f"\n🎬 Video Analysis ({len(video_results)} videos)")
            for vr in video_results:
                self.logger.info(f"\n {vr['video_name']}:")
                self.logger.info(f" Duration: {vr['metadata']['duration']:.1f}s")
                self.logger.info(f" Frames analyzed: {vr['num_frames_analyzed']}")
                self.logger.info(f" Unique seconds: {vr['num_seconds']}")
                self.logger.info(f" Average quality (per-second): {vr['average_quality']:.2f}")
                self.logger.info(f" Quality range: [{vr['min_quality']:.2f}, {vr['max_quality']:.2f}]")
                self.logger.info(f" Variability (std): {vr['std_quality']:.2f}")

        self.logger.info("\n" + "=" * 80 + "\n")
805
+
806
+
807
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``."""
    # The third example previously claimed "frame-level visualization" while
    # passing --viz-granularity second; the description now matches the command.
    epilog = (
        "\n"
        "Examples:\n"
        "  # Analyze video with per-second visualization\n"
        "  python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity second\n"
        "\n"
        "  # Analyze video with both frame and second visualizations\n"
        "  python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity both\n"
        "\n"
        "  # Process directory with per-second visualization\n"
        "  python video_analytics_inference.py --input ./assets/demo_video --task det --video-frames 120 --visualize --viz-granularity second --metric-type consistency\n"
    )
    parser = argparse.ArgumentParser(
        description='MIQA: Machine-centric Image and Video Quality Assessment',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog
    )

    parser.add_argument('--input', type=str, required=True,
                        help='Path to input image/video or directory')
    parser.add_argument('--task', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Task type')
    parser.add_argument('--model', type=str, default='ra_miqa',
                        choices=['ra_miqa'],
                        help='Model architecture (RA-MIQA only; matches Hub registry)')
    parser.add_argument('--metric-type', type=str, default='composite',
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric type')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--video-frames', type=int, default=50,
                        help='Target number of frames to sample from videos')
    parser.add_argument('--save-results', action='store_true',
                        help='Save prediction results to file')
    parser.add_argument('--output-file', type=str, default='predictions.json',
                        help='Output file path')
    parser.add_argument('--visualize', action='store_true',
                        help='Create visualizations')
    parser.add_argument('--viz-dir', type=str, default='inference_results',
                        help='Directory for visualizations')
    parser.add_argument('--viz-granularity', type=str, default='second',
                        choices=['frame', 'second', 'both'],
                        help='Visualization granularity for videos: frame-level, per-second, or both')
    parser.add_argument('--no-progress', action='store_true',
                        help='Disable progress bar')
    return parser


def main():
    """
    Command-line interface for MIQA inference.

    Parses the arguments, runs prediction on ``--input``, prints the summary
    and optionally saves results and visualizations. Any exception raised
    while doing so is reported on stderr with a traceback and the process
    exits with status 1.
    """
    args = _build_arg_parser().parse_args()

    try:
        miqa = MIQAInference(
            task=args.task,
            model_name=args.model,
            metric_type=args.metric_type,
            device=args.device,
            video_target_frames=args.video_frames
        )

        results = miqa.predict(args.input, show_progress=not args.no_progress)
        miqa.print_summary(results)

        if args.save_results:
            miqa.save_results(results, args.output_file)

        if args.visualize:
            miqa.visualize_results(results, args.viz_dir, video_granularity=args.viz_granularity)

    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)
880
+
881
+
882
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()