| import os |
| import sys |
| import torch |
| import argparse |
| import logging |
| from pathlib import Path |
| from typing import List, Dict, Optional, Tuple |
| from collections import OrderedDict, defaultdict |
| import json |
| from datetime import datetime |
| from tqdm import tqdm |
| import numpy as np |
| import cv2 |
|
|
| |
| from PIL import Image, ImageDraw, ImageFont |
| import torchvision.transforms as transforms |
| import matplotlib.pyplot as plt |
| import matplotlib |
|
|
| matplotlib.use('Agg') |
|
|
| |
| from models.MIQA_base import get_torch_model, get_timm_model |
| from models.RA_MIQA import RegionVisionTransformer |
| from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES |
| from utils.hf_download_utils import ensure_checkpoint_from_hf |
|
|
| |
| SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.JPEG', '.png', '.bmp', '.tiff', '.tif'} |
| SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'} |
|
|
|
|
| class VideoFrameExtractor: |
| """ |
| Extracts and samples frames from video files intelligently. |
| |
| This class handles different sampling strategies to balance between |
| thoroughness and computational efficiency. |
| """ |
|
|
| def __init__(self, sampling_strategy: str = 'uniform', |
| target_frames: int = 30, |
| fps_sample: Optional[float] = None): |
| """ |
| Initialize the frame extractor. |
| |
| Args: |
| sampling_strategy: How to sample frames - 'uniform', 'fps', or 'keyframe' |
| target_frames: Target number of frames to extract (for uniform sampling) |
| fps_sample: Sample rate in frames per second (for fps sampling) |
| """ |
| self.sampling_strategy = sampling_strategy |
| self.target_frames = target_frames |
| self.fps_sample = fps_sample |
|
|
| def extract_frames(self, video_path: str) -> Tuple[List[np.ndarray], List[float], Dict]: |
| """ |
| Extract frames from video based on sampling strategy. |
| |
| Returns: |
| Tuple of (frames_list, timestamps_list, video_metadata) |
| """ |
| cap = cv2.VideoCapture(video_path) |
|
|
| if not cap.isOpened(): |
| raise ValueError(f"Cannot open video file: {video_path}") |
|
|
| |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| fps = cap.get(cv2.CAP_PROP_FPS) |
| duration = total_frames / fps if fps > 0 else 0 |
| width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
| height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
| metadata = { |
| 'total_frames': total_frames, |
| 'fps': fps, |
| 'duration': duration, |
| 'width': width, |
| 'height': height |
| } |
|
|
| |
| frame_indices = self._get_sample_indices(total_frames, fps) |
|
|
| frames = [] |
| timestamps = [] |
|
|
| for idx in frame_indices: |
| cap.set(cv2.CAP_PROP_POS_FRAMES, idx) |
| ret, frame = cap.read() |
|
|
| if ret: |
| |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| frames.append(frame_rgb) |
| |
| timestamp = idx / fps if fps > 0 else idx |
| timestamps.append(timestamp) |
|
|
| cap.release() |
|
|
| return frames, timestamps, metadata |
|
|
| def _get_sample_indices(self, total_frames: int, fps: float) -> List[int]: |
| """ |
| Determine which frame indices to sample based on strategy. |
| """ |
| if self.sampling_strategy == 'uniform': |
| |
| if total_frames <= self.target_frames: |
| return list(range(total_frames)) |
| else: |
| |
| step = total_frames / self.target_frames |
| indices = [int(i * step) for i in range(self.target_frames)] |
| return indices |
|
|
| elif self.sampling_strategy == 'fps': |
| |
| if self.fps_sample is None: |
| raise ValueError("fps_sample must be specified for fps sampling strategy") |
|
|
| frame_interval = max(1, int(fps / self.fps_sample)) |
| indices = list(range(0, total_frames, frame_interval)) |
| return indices |
|
|
| else: |
| raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}") |
|
|
|
|
| def aggregate_scores_by_second(frame_results: List[Dict]) -> List[Dict]: |
| """ |
| Aggregate frame-level quality scores to per-second averages. |
| |
| This function groups all frames that fall within the same second |
| and computes their average quality score. This provides a smoothed |
| view of quality over time, reducing noise from frame-to-frame variations. |
| |
| Args: |
| frame_results: List of dictionaries with 'timestamp' and 'quality_score' |
| |
| Returns: |
| List of dictionaries with per-second aggregated scores |
| """ |
| |
| seconds_data = defaultdict(list) |
|
|
| for frame in frame_results: |
| second = int(frame['timestamp']) |
| seconds_data[second].append(frame['quality_score']) |
|
|
| |
| per_second_results = [] |
| for second in sorted(seconds_data.keys()): |
| scores = seconds_data[second] |
| per_second_results.append({ |
| 'second': second, |
| 'timestamp': float(second), |
| 'quality_score': np.mean(scores), |
| 'min_score': np.min(scores), |
| 'max_score': np.max(scores), |
| 'num_frames': len(scores), |
| 'std_score': np.std(scores) if len(scores) > 1 else 0.0 |
| }) |
|
|
| return per_second_results |
|
|
|
|
| class MIQAInference: |
| """ |
| Inference wrapper for MIQA models supporting both images and videos. |
| """ |
|
|
| def __init__(self, task: str, model_name: str = 'ra_miqa', |
| metric_type: str = 'composite', device: Optional[str] = None, |
| video_sampling: str = 'uniform', video_target_frames: int = 30): |
| """ |
| Initialize the MIQA inference system. |
| |
| Args: |
| task: Task type - 'cls', 'det', or 'ins' |
| model_name: Model architecture to use |
| metric_type: Training objective - 'composite', 'consistency', or 'accuracy' |
| device: Device to run inference on |
| video_sampling: Frame sampling strategy for videos |
| video_target_frames: Target number of frames to extract from videos |
| """ |
| self.task = task.lower() |
| self.model_name = model_name |
| self.metric_type = metric_type |
| self.video_target_frames = video_target_frames |
|
|
| |
| self.logger = self._setup_logger() |
|
|
| |
| if device is None: |
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| else: |
| self.device = torch.device(device) |
|
|
| self.logger.info(f"🚀 Initializing MIQA Inference System") |
| self.logger.info(f" Task: {self.task.upper()}") |
| self.logger.info(f" Model: {self.model_name}") |
| self.logger.info(f" Device: {self.device}") |
|
|
| |
| self._validate_config() |
|
|
| |
| self.model = self._load_model() |
|
|
| |
| self.transforms1, self.transforms2 = self._get_transforms() |
|
|
| |
| self.frame_extractor = VideoFrameExtractor( |
| sampling_strategy=video_sampling, |
| target_frames=video_target_frames |
| ) |
|
|
| self.logger.info("✅ System ready for inference\n") |
|
|
| def _setup_logger(self) -> logging.Logger: |
| """Configure logging with both file and console output.""" |
| logger = logging.getLogger('MIQA_Inference') |
| logger.setLevel(logging.INFO) |
|
|
| if logger.hasHandlers(): |
| return logger |
|
|
| logger.propagate = False |
|
|
| |
| console_handler = logging.StreamHandler(sys.stdout) |
| console_handler.setLevel(logging.INFO) |
| console_formatter = logging.Formatter('%(message)s') |
| console_handler.setFormatter(console_formatter) |
| logger.addHandler(console_handler) |
|
|
| return logger |
|
|
| def _validate_config(self) -> None: |
| """Validate that the requested configuration is supported.""" |
|
|
| if self.metric_type not in ['composite', 'consistency', 'accuracy']: |
| raise ValueError( |
| f"Invalid metric_type '{self.metric_type}'. " |
| f"Supported: ['composite', 'consistency', 'accuracy']" |
| ) |
|
|
| if self.task not in MODEL_FILENAMES[self.metric_type]: |
| raise ValueError( |
| f"Invalid task '{self.task}'. " |
| f"Supported tasks: {list(MODEL_FILENAMES[self.metric_type].keys())}" |
| ) |
|
|
| if self.model_name not in MODEL_FILENAMES[self.metric_type][self.task]: |
| available = list(MODEL_FILENAMES[self.metric_type][self.task].keys()) |
| raise ValueError( |
| f"Model '{self.model_name}' not available for task '{self.task}'. " |
| f"Available models: {available}" |
| ) |
|
|
| def _get_checkpoint_path(self) -> str: |
| """Generate the path where model checkpoint should be stored.""" |
| base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric' |
| base_dir.mkdir(parents=True, exist_ok=True) |
|
|
| filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name] |
| return str(base_dir / filename) |
|
|
| def _download_weights(self, checkpoint_path: str) -> bool: |
| """ |
| Download model weights if not present locally. |
| |
| Returns: |
| True if weights are available (already existed or successfully downloaded) |
| """ |
| if os.path.exists(checkpoint_path): |
| self.logger.info(f"✓ Found cached model weights") |
| return True |
|
|
| self.logger.info( |
| f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, " |
| f"file={Path(checkpoint_path).name}, rev={HF_REVISION}" |
| ) |
| try: |
| ensure_checkpoint_from_hf( |
| repo_id=HF_REPO_ID, |
| filename=Path(checkpoint_path).name, |
| local_dir=str(Path(checkpoint_path).parent), |
| revision=HF_REVISION, |
| ) |
| self.logger.info("✓ Successfully downloaded model weights") |
| return True |
| except Exception as e: |
| self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}") |
| return False |
|
|
| def _create_model(self) -> torch.nn.Module: |
| """Create the model architecture.""" |
| if self.model_name == 'ra_miqa': |
| self.logger.info("Building Region-Aware Vision Transformer...") |
| model = RegionVisionTransformer( |
| base_model_name='vit_small_patch16_224', |
| pretrained=False, |
| mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py', |
| checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth' |
| ) |
| else: |
| try: |
| self.logger.info(f"Building {self.model_name} from PyTorch...") |
| model = get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1) |
| except Exception: |
| self.logger.info(f"Building {self.model_name} from timm library...") |
| model = get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1) |
|
|
| return model |
|
|
| def _load_model(self) -> torch.nn.Module: |
| """Load model with weights.""" |
| checkpoint_path = self._get_checkpoint_path() |
|
|
| if not self._download_weights(checkpoint_path): |
| raise RuntimeError("Cannot proceed without model weights") |
|
|
| self.logger.info("🔧 Loading model...") |
| model = self._create_model() |
|
|
| checkpoint = torch.load(checkpoint_path, map_location='cpu') |
| state_dict = checkpoint.get('state_dict', checkpoint) |
|
|
| new_state_dict = OrderedDict() |
| for k, v in state_dict.items(): |
| name = k.replace('module.', '') if k.startswith('module.') else k |
| new_state_dict[name] = v |
|
|
| model.load_state_dict(new_state_dict, strict=True) |
| model = model.to(self.device) |
| model.eval() |
|
|
| self.logger.info("✓ Model loaded successfully") |
| return model |
|
|
| def _get_transforms(self) -> [transforms.Compose, transforms.Compose]: |
| """ |
| Get image preprocessing transforms. |
| |
| These transforms normalize images to match the training distribution. |
| """ |
| IMAGENET_MEAN = (0.485, 0.456, 0.406) |
| IMAGENET_STD = (0.229, 0.224, 0.225) |
| SIMPLE_MEAN = (0.5, 0.5, 0.5) |
| SIMPLE_STD = (0.5, 0.5, 0.5) |
|
|
| transforms_list1 = [ |
| transforms.Resize(288), |
| transforms.CenterCrop(size=224), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=SIMPLE_MEAN, |
| std=SIMPLE_STD) |
| ] |
| transform_list_2 = [ |
| transforms.Resize(288), |
| transforms.CenterCrop((288, 288)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=IMAGENET_MEAN, |
| std=IMAGENET_STD) |
| ] |
| return transforms.Compose(transforms_list1), transforms.Compose(transform_list_2) |
|
|
| def _prepare_frame(self, frame: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Preprocess a video frame for model input. |
| |
| Args: |
| frame: Numpy array in RGB format |
| |
| Returns: |
| Tuple of (cropped_tensor, resized_tensor) |
| """ |
| |
| img = Image.fromarray(frame) |
|
|
| |
| img1 = self.transforms1(img).unsqueeze(0) |
| img2 = self.transforms2(img).unsqueeze(0) |
|
|
| return img1, img2 |
|
|
| def _prepare_image(self, image_path: str) -> Tuple[torch.Tensor, torch.Tensor, Image.Image]: |
| """Load and preprocess an image file.""" |
| img = Image.open(image_path).convert('RGB') |
| img1 = self.transforms1(img).unsqueeze(0) |
| img2 = self.transforms2(img).unsqueeze(0) |
| return img1, img2, img |
|
|
| @torch.no_grad() |
| def predict_single_image(self, image_path: str) -> Dict: |
| """Run inference on a single image.""" |
| img_cropped, img_resized, original_img = self._prepare_image(image_path) |
|
|
| img_cropped = img_cropped.to(self.device) |
| img_resized = img_resized.to(self.device) |
|
|
| output = self.model(img_cropped, img_resized) |
| score = output.item() |
|
|
| return { |
| 'image_path': image_path, |
| 'image_name': Path(image_path).name, |
| 'quality_score': score, |
| 'original_image': original_img, |
| 'type': 'image' |
| } |
|
|
| @torch.no_grad() |
| def predict_video(self, video_path: str, show_progress: bool = True) -> Dict: |
| """ |
| Run inference on a video file. |
| |
| This extracts frames, predicts quality for each, aggregates to per-second |
| averages, and returns comprehensive time-series data suitable for visualization. |
| """ |
| self.logger.info(f"🎬 Processing video: {Path(video_path).name}") |
|
|
| |
| frames, timestamps, metadata = self.frame_extractor.extract_frames(video_path) |
|
|
| self.logger.info(f" Extracted {len(frames)} frames from {metadata['duration']:.1f}s video") |
|
|
| |
| frame_results = [] |
| iterator = tqdm(frames, desc="Analyzing frames", disable=not show_progress, ncols=80) |
|
|
| for frame, timestamp in zip(iterator, timestamps): |
| img_cropped, img_resized = self._prepare_frame(frame) |
|
|
| img_cropped = img_cropped.to(self.device) |
| img_resized = img_resized.to(self.device) |
|
|
| output = self.model(img_cropped, img_resized) |
| score = output.item() |
|
|
| frame_results.append({ |
| 'timestamp': timestamp, |
| 'quality_score': score, |
| 'frame': frame |
| }) |
|
|
| |
| per_second_results = aggregate_scores_by_second(frame_results) |
|
|
| |
| second_scores = [r['quality_score'] for r in per_second_results] |
|
|
| return { |
| 'video_path': video_path, |
| 'video_name': Path(video_path).name, |
| 'type': 'video', |
| 'metadata': metadata, |
| 'frame_results': frame_results, |
| 'per_second_results': per_second_results, |
| 'num_frames_analyzed': len(frame_results), |
| 'num_seconds': len(per_second_results), |
| 'average_quality': np.mean(second_scores), |
| 'min_quality': np.min(second_scores), |
| 'max_quality': np.max(second_scores), |
| 'std_quality': np.std(second_scores) |
| } |
|
|
| def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]: |
| """ |
| Main prediction interface - handles images, videos, and directories. |
| """ |
| input_path = Path(input_path) |
|
|
| |
| if input_path.is_file(): |
| ext = input_path.suffix.lower() |
|
|
| if ext in SUPPORTED_IMAGE_EXTENSIONS: |
| return [self.predict_single_image(str(input_path))] |
| elif ext in SUPPORTED_VIDEO_EXTENSIONS: |
| return [self.predict_video(str(input_path), show_progress)] |
| else: |
| raise ValueError(f"Unsupported file extension: {ext}") |
|
|
| |
| elif input_path.is_dir(): |
| results = [] |
|
|
| |
| image_paths = [] |
| video_paths = [] |
|
|
| for ext in SUPPORTED_IMAGE_EXTENSIONS: |
| image_paths.extend(input_path.glob(f"*{ext}")) |
|
|
| for ext in SUPPORTED_VIDEO_EXTENSIONS: |
| video_paths.extend(input_path.glob(f"*{ext}")) |
|
|
| image_paths = sorted([str(p) for p in image_paths]) |
| video_paths = sorted([str(p) for p in video_paths]) |
|
|
| if not image_paths and not video_paths: |
| raise ValueError(f"No supported files found in {input_path}") |
|
|
| self.logger.info(f"📁 Found {len(image_paths)} images and {len(video_paths)} videos") |
|
|
| |
| if image_paths: |
| for img_path in tqdm(image_paths, desc="Processing images", ncols=80): |
| try: |
| result = self.predict_single_image(img_path) |
| results.append(result) |
| except Exception as e: |
| self.logger.warning(f"⚠️ Failed: {img_path}") |
|
|
| |
| if video_paths: |
| for vid_path in video_paths: |
| try: |
| result = self.predict_video(vid_path, show_progress) |
| results.append(result) |
| except Exception as e: |
| self.logger.warning(f"⚠️ Failed: {vid_path}") |
|
|
| return results |
| else: |
| raise ValueError(f"Input path does not exist: {input_path}") |
|
|
| def visualize_video_results(self, video_result: Dict, output_dir: str = 'inference_results', |
| granularity: str = 'second') -> None: |
| """ |
| Create time-series visualization for video quality predictions. |
| |
| This generates a line plot showing how quality varies across the video timeline. |
| You can choose between frame-level and second-level granularity. |
| |
| Args: |
| video_result: Dictionary containing video analysis results |
| output_dir: Directory to save visualizations |
| granularity: Visualization granularity - 'frame' for frame-by-frame, |
| 'second' for per-second averages, or 'both' for dual plot |
| """ |
| output_path = Path(output_dir) |
| output_path.mkdir(parents=True, exist_ok=True) |
|
|
| if granularity == 'both': |
| |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6)) |
|
|
| |
| frame_results = video_result['frame_results'] |
| frame_timestamps = [r['timestamp'] for r in frame_results] |
| frame_scores = [r['quality_score'] for r in frame_results] |
|
|
| ax1.plot(frame_timestamps, frame_scores, linewidth=1.5, color='#2E86AB', |
| marker='o', markersize=3, markerfacecolor='white', markeredgewidth=1, |
| alpha=0.7, label='Frame-level') |
| ax1.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold') |
| ax1.set_ylabel('Quality Score', fontsize=11, fontweight='bold') |
| ax1.set_title('Frame-Level Quality', fontsize=12, fontweight='bold') |
| ax1.grid(True, alpha=0.3, linestyle='--') |
| ax1.set_ylim(0, 1) |
| |
| second_results = video_result['per_second_results'] |
| second_timestamps = [r['timestamp'] for r in second_results] |
| second_scores = [r['quality_score'] for r in second_results] |
|
|
| ax2.plot(second_timestamps, second_scores, linewidth=2.5, color='#A23B72', |
| marker='s', markersize=6, markerfacecolor='white', markeredgewidth=1.5, |
| label='Per-second average') |
|
|
| |
| if second_results and 'std_score' in second_results[0]: |
| stds = [r['std_score'] for r in second_results] |
| ax2.fill_between(second_timestamps, |
| np.array(second_scores) - np.array(stds), |
| np.array(second_scores) + np.array(stds), |
| alpha=0.2, color='#A23B72') |
|
|
| ax2.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold') |
| ax2.set_ylabel('Quality Score', fontsize=11, fontweight='bold') |
| ax2.set_title('Per-Second Averaged Quality', fontsize=12, fontweight='bold') |
| ax2.grid(True, alpha=0.3, linestyle='--') |
| ax2.set_ylim(0, 1) |
| |
| avg_score = video_result['average_quality'] |
| ax1.axhline(y=avg_score, color='#F18F01', linestyle='--', |
| linewidth=1.5, alpha=0.7, label=f'Overall avg: {avg_score:.2f}') |
| ax2.axhline(y=avg_score, color='#F18F01', linestyle='--', |
| linewidth=1.5, alpha=0.7, label=f'Overall avg: {avg_score:.2f}') |
|
|
| ax1.legend(loc='best', framealpha=0.9) |
| ax2.legend(loc='best', framealpha=0.9) |
|
|
| plt.suptitle(f"Video Quality Analysis: {video_result['video_name']}, {self.task}-oriented MIQA", |
| fontsize=14, fontweight='bold', y=1.02) |
|
|
| suffix = 'comparison' |
|
|
| else: |
| |
| plt.figure(figsize=(14, 6)) |
|
|
| if granularity == 'frame': |
| frame_results = video_result['frame_results'] |
| timestamps = [r['timestamp'] for r in frame_results] |
| scores = [r['quality_score'] for r in frame_results] |
| plot_color = '#2E86AB' |
| plot_label = 'Frame-level quality' |
| title_suffix = '(Frame-Level)' |
| suffix = 'frame' |
| marker_size = 4 |
| else: |
| second_results = video_result['per_second_results'] |
| timestamps = [r['timestamp'] for r in second_results] |
| scores = [r['quality_score'] for r in second_results] |
| plot_color = '#A23B72' |
| plot_label = 'Per-second average' |
| title_suffix = '(Per-Second Average)' |
| suffix = 'second' |
| marker_size = 6 |
|
|
| |
| plt.plot(timestamps, scores, linewidth=2, color=plot_color, marker='o', |
| markersize=marker_size, markerfacecolor='white', markeredgewidth=1.5, |
| label=plot_label) |
|
|
| |
| if granularity == 'second' and second_results and 'std_score' in second_results[0]: |
| stds = [r['std_score'] for r in second_results] |
| plt.fill_between(timestamps, |
| np.array(scores) - np.array(stds), |
| np.array(scores) + np.array(stds), |
| alpha=0.2, color=plot_color) |
|
|
| |
| avg_score = video_result['average_quality'] |
| plt.axhline(y=avg_score, color='#F18F01', linestyle='--', |
| linewidth=1.5, label=f'Average: {avg_score:.2f}') |
|
|
| |
| plt.xlabel('Time (seconds)', fontsize=12, fontweight='bold') |
| plt.ylabel('Quality Score', fontsize=12, fontweight='bold') |
| plt.title(f"Video Quality Analysis: {video_result['video_name']} {title_suffix}", |
| fontsize=14, fontweight='bold', pad=20) |
| plt.grid(True, alpha=0.3, linestyle='--') |
| plt.legend(loc='best', framealpha=0.9) |
|
|
| |
| if granularity == 'both': |
| stats_text = ( |
| f"Duration: {video_result['metadata']['duration']:.1f}s\n" |
| f"Frames: {video_result['num_frames_analyzed']} | " |
| f"Seconds: {video_result['num_seconds']}\n" |
| f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n" |
| f"Std Dev: {video_result['std_quality']:.2f}" |
| ) |
| |
| ax2.text(0.02, 0.98, stats_text, transform=ax2.transAxes, |
| fontsize=9, verticalalignment='top', |
| bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) |
| else: |
| stats_text = ( |
| f"Duration: {video_result['metadata']['duration']:.1f}s\n" |
| f"Frames Analyzed: {video_result['num_frames_analyzed']}\n" |
| f"Unique Seconds: {video_result['num_seconds']}\n" |
| f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n" |
| f"Std Dev: {video_result['std_quality']:.2f}" |
| ) |
| plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, |
| fontsize=12, verticalalignment='top', |
| bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.6)) |
|
|
| plt.tight_layout() |
|
|
| |
| output_file = output_path / f"{Path(video_result['video_name']).stem}_{self.metric_type}_quality_{suffix}.png" |
| plt.savefig(output_file, dpi=300, bbox_inches='tight') |
| plt.close() |
|
|
| self.logger.info(f" Saved visualization: {output_file.name}") |
|
|
| def visualize_results(self, results: List[Dict], output_dir: str = 'inference_results', |
| video_granularity: str = 'second') -> None: |
| """ |
| Create visualizations for all results (images and videos). |
| |
| Args: |
| results: List of prediction results |
| output_dir: Directory to save visualizations |
| video_granularity: For videos - 'frame', 'second', or 'both' |
| """ |
| output_path = Path(output_dir) |
| output_path.mkdir(parents=True, exist_ok=True) |
|
|
| self.logger.info(f"\n🎨 Creating visualizations (granularity: {video_granularity})...") |
|
|
| for result in results: |
| if result['type'] == 'video': |
| self.visualize_video_results(result, output_dir, video_granularity) |
| elif result['type'] == 'image' and result.get('quality_score') is not None: |
| |
| img = result['original_image'].copy() |
| draw = ImageDraw.Draw(img) |
|
|
| score = result['quality_score'] |
| score_text = f"Quality: {score:.3f}" |
|
|
| |
| |
| norm_score = max(0, score) |
|
|
| if norm_score < 0.5: |
| r, g, b = 255, int(255 * norm_score * 2), 0 |
| else: |
| r, g, b = int(255 * (2 - norm_score * 2)), 255, 0 |
|
|
| color = (r, g, b) |
|
|
| box_coords = [10, 10, 260, 60] |
| draw.rectangle(box_coords, fill=color) |
|
|
| try: |
| font = ImageFont.truetype("arial.ttf", 24) |
| except: |
| font = ImageFont.load_default() |
|
|
| draw.text((15, 20), score_text, fill='black', font=font) |
|
|
| output_file = output_path / f"annotated_{result['image_name']}" |
| img.save(output_file) |
|
|
| self.logger.info(f"✓ Visualizations saved to: {output_dir}/") |
|
|
| def save_results(self, results: List[Dict], output_path: str = 'predictions.json') -> None: |
| """Save prediction results to JSON file.""" |
| |
| clean_results = [] |
| for r in results: |
| clean_r = {k: v for k, v in r.items() |
| if k not in ['original_image', 'frame']} |
|
|
| |
| if clean_r.get('type') == 'video' and 'frame_results' in clean_r: |
| clean_r['frame_results'] = [ |
| {k: v for k, v in fr.items() if k != 'frame'} |
| for fr in clean_r['frame_results'] |
| ] |
|
|
| clean_results.append(clean_r) |
|
|
| output_path = Path(output_path) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| with open(output_path, 'w') as f: |
| json.dump({ |
| 'metadata': { |
| 'task': self.task, |
| 'model': self.model_name, |
| 'metric_type': self.metric_type, |
| 'timestamp': datetime.now().isoformat(), |
| 'total_files': len(clean_results) |
| }, |
| 'predictions': clean_results |
| }, f, indent=2) |
|
|
| self.logger.info(f"💾 Results saved to: {output_path}") |
|
|
| def print_summary(self, results: List[Dict]) -> None: |
| """Print formatted summary of prediction results.""" |
| self.logger.info("\n" + "=" * 80) |
| self.logger.info("PREDICTION SUMMARY") |
| self.logger.info("=" * 80) |
|
|
| image_results = [r for r in results if r.get('type') == 'image'] |
| video_results = [r for r in results if r.get('type') == 'video'] |
|
|
| if image_results: |
| valid_images = [r for r in image_results if r.get('quality_score') is not None] |
| if valid_images: |
| scores = [r['quality_score'] for r in valid_images] |
| self.logger.info(f"\n📸 Image Analysis ({len(valid_images)} images)") |
| self.logger.info(f" Average quality: {np.mean(scores):.2f}") |
| self.logger.info(f" Score range: [{np.min(scores):.2f}, {np.max(scores):.2f}]") |
|
|
| if video_results: |
| self.logger.info(f"\n🎬 Video Analysis ({len(video_results)} videos)") |
| for vr in video_results: |
| self.logger.info(f"\n {vr['video_name']}:") |
| self.logger.info(f" Duration: {vr['metadata']['duration']:.1f}s") |
| self.logger.info(f" Frames analyzed: {vr['num_frames_analyzed']}") |
| self.logger.info(f" Unique seconds: {vr['num_seconds']}") |
| self.logger.info(f" Average quality (per-second): {vr['average_quality']:.2f}") |
| self.logger.info(f" Quality range: [{vr['min_quality']:.2f}, {vr['max_quality']:.2f}]") |
| self.logger.info(f" Variability (std): {vr['std_quality']:.2f}") |
|
|
| self.logger.info("\n" + "=" * 80 + "\n") |
|
|
|
|
| def main(): |
| """Command-line interface for MIQA inference.""" |
| parser = argparse.ArgumentParser( |
| description='MIQA: Machine-centric Image and Video Quality Assessment', |
| formatter_class=argparse.RawDescriptionHelpFormatter, |
| epilog=""" |
| Examples: |
| # Analyze video with per-second visualization |
| python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity second |
| |
| # Analyze video with both frame and second visualizations |
| python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity both |
| |
| # Process directory with frame-level visualization |
| python video_analytics_inference.py --input ./assets/demo_video --task det --video-frames 120 --visualize --viz-granularity second --metric-type consistency |
| """ |
| ) |
|
|
| parser.add_argument('--input', type=str, required=True, |
| help='Path to input image/video or directory') |
| parser.add_argument('--task', type=str, required=True, |
| choices=['cls', 'det', 'ins'], |
| help='Task type') |
| parser.add_argument('--model', type=str, default='ra_miqa', |
| choices=['ra_miqa'], |
| help='Model architecture (RA-MIQA only; matches Hub registry)') |
| parser.add_argument('--metric-type', type=str, default='composite', |
| choices=['composite', 'consistency', 'accuracy'], |
| help='Training metric type') |
| parser.add_argument('--device', type=str, default=None, |
| choices=['cuda', 'cpu'], |
| help='Device to run on') |
| parser.add_argument('--video-frames', type=int, default=50, |
| help='Target number of frames to sample from videos') |
| parser.add_argument('--save-results', action='store_true', |
| help='Save prediction results to file') |
| parser.add_argument('--output-file', type=str, default='predictions.json', |
| help='Output file path') |
| parser.add_argument('--visualize', action='store_true', |
| help='Create visualizations') |
| parser.add_argument('--viz-dir', type=str, default='inference_results', |
| help='Directory for visualizations') |
| parser.add_argument('--viz-granularity', type=str, default='second', |
| choices=['frame', 'second', 'both'], |
| help='Visualization granularity for videos: frame-level, per-second, or both') |
| parser.add_argument('--no-progress', action='store_true', |
| help='Disable progress bar') |
|
|
| args = parser.parse_args() |
|
|
| try: |
| miqa = MIQAInference( |
| task=args.task, |
| model_name=args.model, |
| metric_type=args.metric_type, |
| device=args.device, |
| video_target_frames=args.video_frames |
| ) |
|
|
| results = miqa.predict(args.input, show_progress=not args.no_progress) |
| miqa.print_summary(results) |
|
|
| if args.save_results: |
| miqa.save_results(results, args.output_file) |
|
|
| if args.visualize: |
| miqa.visualize_results(results, args.viz_dir, video_granularity=args.viz_granularity) |
|
|
| except Exception as e: |
| print(f"\n❌ Error: {str(e)}", file=sys.stderr) |
| import traceback |
| traceback.print_exc() |
| sys.exit(1) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|