"""MIQA inference for images and videos.

Provides :class:`MIQAInference`, a wrapper that loads a MIQA checkpoint,
scores single images, samples and scores video frames, aggregates video
scores per second, and optionally writes JSON results and PNG plots.
Run the module as a script for the command-line interface (see ``main``).
"""

import argparse
import json
import logging
import os
import sys
from collections import OrderedDict, defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib

# Select the non-interactive backend *before* pyplot is imported so that no
# GUI backend is ever initialised (required on headless machines).
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

# Import your existing model components
from models.MIQA_base import get_torch_model, get_timm_model
from models.RA_MIQA import RegionVisionTransformer
from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
from utils.hf_download_utils import ensure_checkpoint_from_hf

# Supported file extensions (matching is done on the lower-cased suffix)
SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.JPEG', '.png', '.bmp', '.tiff', '.tif'}
SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}


class VideoFrameExtractor:
    """
    Extracts and samples frames from video files.

    Two sampling strategies are implemented: ``'uniform'`` (a fixed number of
    frames spread over the whole video) and ``'fps'`` (one frame every
    ``fps / fps_sample`` source frames).
    """

    def __init__(self,
                 sampling_strategy: str = 'uniform',
                 target_frames: int = 30,
                 fps_sample: Optional[float] = None):
        """
        Initialize the frame extractor.

        Args:
            sampling_strategy: How to sample frames - 'uniform' or 'fps'
            target_frames: Target number of frames to extract (for uniform sampling)
            fps_sample: Sample rate in frames per second (for fps sampling)
        """
        self.sampling_strategy = sampling_strategy
        self.target_frames = target_frames
        self.fps_sample = fps_sample

    def extract_frames(self, video_path: str) -> Tuple[List[np.ndarray], List[float], Dict]:
        """
        Extract frames from video based on sampling strategy.

        Args:
            video_path: Path of the video file to open with OpenCV.

        Returns:
            Tuple of (frames_list, timestamps_list, video_metadata). Frames are
            converted from BGR to RGB. Timestamps are ``index / fps`` seconds,
            or the raw frame index when the container reports no frame rate.
            Frames that fail to decode are skipped, so the lists may be shorter
            than the number of requested indices.

        Raises:
            ValueError: If the video cannot be opened or the sampling
                configuration is invalid.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video file: {video_path}")

            # Get video properties
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            duration = total_frames / fps if fps > 0 else 0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            metadata = {
                'total_frames': total_frames,
                'fps': fps,
                'duration': duration,
                'width': width,
                'height': height
            }

            # Determine which frames to sample
            frame_indices = self._get_sample_indices(total_frames, fps)

            frames = []
            timestamps = []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if not ret:
                    continue
                # Convert BGR to RGB (OpenCV uses BGR)
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # Calculate timestamp in seconds
                timestamps.append(idx / fps if fps > 0 else idx)
        finally:
            # Always free the capture handle, even if decoding raised.
            cap.release()

        return frames, timestamps, metadata

    def _get_sample_indices(self, total_frames: int, fps: float) -> List[int]:
        """
        Determine which frame indices to sample based on strategy.

        Args:
            total_frames: Frame count reported for the video.
            fps: Frame rate reported for the video.

        Returns:
            Frame indices in ascending order.

        Raises:
            ValueError: If the strategy is unknown, ``target_frames`` is not
                positive (uniform), or ``fps_sample`` is missing or not
                positive (fps).
        """
        if self.sampling_strategy == 'uniform':
            if self.target_frames <= 0:
                raise ValueError("target_frames must be a positive integer")
            # Short video: take every frame
            if total_frames <= self.target_frames:
                return list(range(total_frames))
            # Calculate step size to get approximately target_frames
            step = total_frames / self.target_frames
            return [int(i * step) for i in range(self.target_frames)]

        if self.sampling_strategy == 'fps':
            # Sample at a specific frame rate
            if self.fps_sample is None:
                raise ValueError("fps_sample must be specified for fps sampling strategy")
            if self.fps_sample <= 0:
                raise ValueError("fps_sample must be a positive number")
            # An interval of at least 1 keeps range() valid when the requested
            # rate exceeds the source rate or the source rate is unknown (0).
            frame_interval = max(1, int(fps / self.fps_sample))
            return list(range(0, total_frames, frame_interval))

        raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")


def aggregate_scores_by_second(frame_results: List[Dict]) -> List[Dict]:
    """
    Aggregate frame-level quality scores to per-second statistics.

    Frames are grouped by ``int(timestamp)`` and each group is reduced to its
    mean, min, max, standard deviation and frame count.

    Args:
        frame_results: List of dictionaries with 'timestamp' and 'quality_score'

    Returns:
        List of dictionaries with per-second aggregated scores, ordered by
        second. Returns an empty list for empty input.
    """
    # Group frames by their second (integer part of timestamp)
    seconds_data = defaultdict(list)
    for frame in frame_results:
        seconds_data[int(frame['timestamp'])].append(frame['quality_score'])

    # Calculate statistics for each second
    per_second_results = []
    for second in sorted(seconds_data):
        scores = seconds_data[second]
        per_second_results.append({
            'second': second,
            'timestamp': float(second),  # Use second as timestamp for plotting
            'quality_score': float(np.mean(scores)),
            'min_score': float(np.min(scores)),
            'max_score': float(np.max(scores)),
            'num_frames': len(scores),
            'std_score': float(np.std(scores)) if len(scores) > 1 else 0.0
        })

    return per_second_results


class MIQAInference:
    """
    Inference wrapper for MIQA models supporting both images and videos.
    """

    def __init__(self,
                 task: str,
                 model_name: str = 'ra_miqa',
                 metric_type: str = 'composite',
                 device: Optional[str] = None,
                 video_sampling: str = 'uniform',
                 video_target_frames: int = 30,
                 video_fps_sample: Optional[float] = None):
        """
        Initialize the MIQA inference system.

        Args:
            task: Task type - 'cls', 'det', or 'ins'
            model_name: Model architecture to use
            metric_type: Training objective - 'composite', 'consistency', or 'accuracy'
            device: Device to run inference on; defaults to CUDA when available, else CPU
            video_sampling: Frame sampling strategy for videos ('uniform' or 'fps')
            video_target_frames: Target number of frames to extract from videos
            video_fps_sample: Sampling rate in frames per second, required
                when ``video_sampling`` is 'fps'

        Raises:
            ValueError: If the task / model / metric combination is not in the registry.
            RuntimeError: If the model weights cannot be obtained.
        """
        self.task = task.lower()
        self.model_name = model_name
        self.metric_type = metric_type
        self.video_target_frames = video_target_frames

        # Setup logging
        self.logger = self._setup_logger()

        # Determine device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.logger.info("🚀 Initializing MIQA Inference System")
        self.logger.info(f" Task: {self.task.upper()}")
        self.logger.info(f" Model: {self.model_name}")
        self.logger.info(f" Device: {self.device}")

        # Validate configuration
        self._validate_config()

        # Initialize model
        self.model = self._load_model()

        # Setup image preprocessing
        self.transforms1, self.transforms2 = self._get_transforms()

        # Initialize video frame extractor
        self.frame_extractor = VideoFrameExtractor(
            sampling_strategy=video_sampling,
            target_frames=video_target_frames,
            fps_sample=video_fps_sample
        )

        self.logger.info("✅ System ready for inference\n")

    def _setup_logger(self) -> logging.Logger:
        """Configure the 'MIQA_Inference' logger with console (stdout) output."""
        logger = logging.getLogger('MIQA_Inference')
        logger.setLevel(logging.INFO)

        # Check this logger's own handlers only: hasHandlers() also reports
        # handlers on ancestor loggers, which would skip our setup whenever
        # the host application has configured the root logger.
        if logger.handlers:
            return logger

        logger.propagate = False

        # Console handler with clean formatting
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(console_handler)

        return logger

    def _validate_config(self) -> None:
        """Validate that the requested configuration is supported.

        Raises:
            ValueError: If the metric type, task, or model name is not
                present in ``MODEL_FILENAMES``.
        """
        if self.metric_type not in ['composite', 'consistency', 'accuracy']:
            raise ValueError(
                f"Invalid metric_type '{self.metric_type}'. "
                f"Supported: ['composite', 'consistency', 'accuracy']"
            )

        if self.task not in MODEL_FILENAMES[self.metric_type]:
            raise ValueError(
                f"Invalid task '{self.task}'. "
                f"Supported tasks: {list(MODEL_FILENAMES[self.metric_type].keys())}"
            )

        if self.model_name not in MODEL_FILENAMES[self.metric_type][self.task]:
            available = list(MODEL_FILENAMES[self.metric_type][self.task].keys())
            raise ValueError(
                f"Model '{self.model_name}' not available for task '{self.task}'. "
                f"Available models: {available}"
            )

    def _get_checkpoint_path(self) -> str:
        """Return the local checkpoint path, creating its directory if needed."""
        base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
        base_dir.mkdir(parents=True, exist_ok=True)

        filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
        return str(base_dir / filename)

    def _download_weights(self, checkpoint_path: str) -> bool:
        """
        Download model weights if not present locally.

        Args:
            checkpoint_path: Expected local path of the checkpoint file.

        Returns:
            True if weights are available (already existed or successfully
            downloaded), False if the download failed.
        """
        if os.path.exists(checkpoint_path):
            self.logger.info("✓ Found cached model weights")
            return True

        self.logger.info(
            f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
            f"file={Path(checkpoint_path).name}, rev={HF_REVISION}"
        )
        try:
            ensure_checkpoint_from_hf(
                repo_id=HF_REPO_ID,
                filename=Path(checkpoint_path).name,
                local_dir=str(Path(checkpoint_path).parent),
                revision=HF_REVISION,
            )
        except Exception as e:
            # Best-effort boundary: the caller turns False into a RuntimeError.
            self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
            return False

        self.logger.info("✓ Successfully downloaded model weights")
        return True

    def _create_model(self) -> torch.nn.Module:
        """Create the (untrained) model architecture for ``self.model_name``."""
        if self.model_name == 'ra_miqa':
            self.logger.info("Building Region-Aware Vision Transformer...")
            model = RegionVisionTransformer(
                base_model_name='vit_small_patch16_224',
                pretrained=False,  # We'll load our trained weights
                mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
                checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
            )
        else:
            # Try the torch factory first and fall back to timm on any failure.
            try:
                self.logger.info(f"Building {self.model_name} from PyTorch...")
                model = get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
            except Exception:
                self.logger.info(f"Building {self.model_name} from timm library...")
                model = get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)

        return model

    def _load_model(self) -> torch.nn.Module:
        """Load model with weights, move it to ``self.device`` and set eval mode.

        Raises:
            RuntimeError: If the checkpoint is neither cached nor downloadable.
        """
        checkpoint_path = self._get_checkpoint_path()

        if not self._download_weights(checkpoint_path):
            raise RuntimeError("Cannot proceed without model weights")

        self.logger.info("🔧 Loading model...")
        model = self._create_model()

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)

        # Strip only a *leading* 'module.' (as added by DataParallel-style
        # wrappers); inner occurrences of the substring must stay untouched.
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            new_state_dict[name] = value

        model.load_state_dict(new_state_dict, strict=True)
        model = model.to(self.device)
        model.eval()

        self.logger.info("✓ Model loaded successfully")
        return model

    def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose]:
        """
        Get image preprocessing transforms.

        Returns:
            Tuple of two pipelines. The first resizes to 288, center-crops to
            224 and normalizes with mean/std 0.5; the second resizes to 288,
            center-crops to 288x288 and normalizes with ImageNet statistics.
        """
        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)
        SIMPLE_MEAN = (0.5, 0.5, 0.5)
        SIMPLE_STD = (0.5, 0.5, 0.5)

        transforms_list1 = [
            transforms.Resize(288),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD)
        ]
        transform_list_2 = [
            transforms.Resize(288),
            transforms.CenterCrop((288, 288)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ]

        return transforms.Compose(transforms_list1), transforms.Compose(transform_list_2)

    def _prepare_frame(self, frame: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a video frame for model input.

        Args:
            frame: Numpy array in RGB format

        Returns:
            Tuple of (cropped_tensor, resized_tensor), each with a leading
            batch dimension of 1.
        """
        # Convert numpy array to PIL Image
        img = Image.fromarray(frame)

        # Apply transforms
        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0)

        return img1, img2

    def _prepare_image(self, image_path: str) -> Tuple[torch.Tensor, torch.Tensor, Image.Image]:
        """Load and preprocess an image file.

        Returns:
            Tuple of (cropped_tensor, resized_tensor, rgb_image), the tensors
            each with a leading batch dimension of 1.
        """
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy that stays valid afterwards.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        img1 = self.transforms1(img).unsqueeze(0)
        img2 = self.transforms2(img).unsqueeze(0)

        return img1, img2, img

    @torch.no_grad()
    def predict_single_image(self, image_path: str) -> Dict:
        """Run inference on a single image.

        Returns:
            Dictionary with the image path and name, the scalar
            'quality_score', the loaded RGB image under 'original_image',
            and 'type' set to 'image'.
        """
        img_cropped, img_resized, original_img = self._prepare_image(image_path)
        img_cropped = img_cropped.to(self.device)
        img_resized = img_resized.to(self.device)

        output = self.model(img_cropped, img_resized)
        score = output.item()

        return {
            'image_path': image_path,
            'image_name': Path(image_path).name,
            'quality_score': score,
            'original_image': original_img,
            'type': 'image'
        }

    @torch.no_grad()
    def predict_video(self, video_path: str, show_progress: bool = True) -> Dict:
        """
        Run inference on a video file.

        This extracts frames, predicts quality for each, aggregates to
        per-second averages, and returns time-series data suitable for
        visualization. The summary statistics ('average_quality' etc.) are
        computed over the per-second averages, not over individual frames.

        Args:
            video_path: Path of the video to analyze.
            show_progress: Whether to display a tqdm progress bar.

        Raises:
            ValueError: If the video cannot be opened or no frame could be read.
        """
        self.logger.info(f"🎬 Processing video: {Path(video_path).name}")

        # Extract frames
        frames, timestamps, metadata = self.frame_extractor.extract_frames(video_path)
        if not frames:
            # Fail with a clear message instead of an opaque NumPy
            # zero-size-reduction error further down.
            raise ValueError(f"No frames could be read from video: {video_path}")

        self.logger.info(f" Extracted {len(frames)} frames from {metadata['duration']:.1f}s video")

        # Process each frame
        frame_results = []
        iterator = tqdm(frames, desc="Analyzing frames", disable=not show_progress, ncols=80)

        for frame, timestamp in zip(iterator, timestamps):
            img_cropped, img_resized = self._prepare_frame(frame)
            img_cropped = img_cropped.to(self.device)
            img_resized = img_resized.to(self.device)

            output = self.model(img_cropped, img_resized)
            score = output.item()

            frame_results.append({
                'timestamp': timestamp,
                'quality_score': score,
                'frame': frame  # Store for visualization if needed
            })

        # Aggregate frame results by second
        per_second_results = aggregate_scores_by_second(frame_results)

        # Calculate statistics from per-second data for better representation
        second_scores = [r['quality_score'] for r in per_second_results]

        return {
            'video_path': video_path,
            'video_name': Path(video_path).name,
            'type': 'video',
            'metadata': metadata,
            'frame_results': frame_results,
            'per_second_results': per_second_results,
            'num_frames_analyzed': len(frame_results),
            'num_seconds': len(per_second_results),
            'average_quality': float(np.mean(second_scores)),
            'min_quality': float(np.min(second_scores)),
            'max_quality': float(np.max(second_scores)),
            'std_quality': float(np.std(second_scores))
        }

    def predict(self, input_path: str, show_progress: bool = True) -> List[Dict]:
        """
        Main prediction interface - handles images, videos, and directories.

        For a directory, only its direct children are considered (no
        recursion); files that fail are logged and skipped.

        Args:
            input_path: Path to an image, a video, or a directory of them.
            show_progress: Whether to display progress bars.

        Returns:
            List of result dictionaries (images first, then videos).

        Raises:
            ValueError: If the path does not exist, has an unsupported
                extension, or the directory holds no supported files.
        """
        source = Path(input_path)

        # Handle single file
        if source.is_file():
            ext = source.suffix.lower()
            if ext in SUPPORTED_IMAGE_EXTENSIONS:
                return [self.predict_single_image(str(source))]
            if ext in SUPPORTED_VIDEO_EXTENSIONS:
                return [self.predict_video(str(source), show_progress)]
            raise ValueError(f"Unsupported file extension: {ext}")

        if not source.is_dir():
            raise ValueError(f"Input path does not exist: {source}")

        # Handle directory: classify by lower-cased suffix so that matching is
        # case-insensitive (like the single-file path) and every file is
        # listed exactly once regardless of filesystem case sensitivity.
        image_paths = []
        video_paths = []
        for entry in source.iterdir():
            ext = entry.suffix.lower()
            if ext in SUPPORTED_IMAGE_EXTENSIONS:
                image_paths.append(str(entry))
            elif ext in SUPPORTED_VIDEO_EXTENSIONS:
                video_paths.append(str(entry))
        image_paths.sort()
        video_paths.sort()

        if not image_paths and not video_paths:
            raise ValueError(f"No supported files found in {source}")

        self.logger.info(f"📁 Found {len(image_paths)} images and {len(video_paths)} videos")

        results = []

        # Process images
        for img_path in tqdm(image_paths, desc="Processing images",
                             disable=not show_progress or not image_paths, ncols=80):
            try:
                results.append(self.predict_single_image(img_path))
            except Exception as e:
                # Best-effort batch: report the cause and continue.
                self.logger.warning(f"⚠️ Failed: {img_path}: {e}")

        # Process videos
        for vid_path in video_paths:
            try:
                results.append(self.predict_video(vid_path, show_progress))
            except Exception as e:
                self.logger.warning(f"⚠️ Failed: {vid_path}: {e}")

        return results

    def visualize_video_results(self, video_result: Dict,
                                output_dir: str = 'inference_results',
                                granularity: str = 'second') -> None:
        """
        Create time-series visualization for video quality predictions.

        This generates a line plot showing how quality varies across the video
        timeline and saves it as a PNG in ``output_dir``.

        Args:
            video_result: Dictionary containing video analysis results
            output_dir: Directory to save visualizations
            granularity: Visualization granularity - 'frame' for frame-by-frame,
                'second' for per-second averages, or 'both' for dual plot.
                Any other value is plotted like 'second' without the
                standard-deviation band.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        avg_score = video_result['average_quality']

        def shade_std(axis, xs, ys, per_second, color):
            """Shade +/- one std-dev around the per-second means, if available."""
            if per_second and 'std_score' in per_second[0]:
                stds = np.array([r['std_score'] for r in per_second])
                means = np.array(ys)
                axis.fill_between(xs, means - stds, means + stds, alpha=0.2, color=color)

        if granularity == 'both':
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
        else:
            fig = plt.figure(figsize=(14, 6))

        try:
            if granularity == 'both':
                # Frame-level plot (left)
                frame_results = video_result['frame_results']
                frame_timestamps = [r['timestamp'] for r in frame_results]
                frame_scores = [r['quality_score'] for r in frame_results]

                ax1.plot(frame_timestamps, frame_scores,
                         linewidth=1.5, color='#2E86AB', marker='o',
                         markersize=3, markerfacecolor='white',
                         markeredgewidth=1, alpha=0.7, label='Frame-level')
                ax1.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold')
                ax1.set_ylabel('Quality Score', fontsize=11, fontweight='bold')
                ax1.set_title('Frame-Level Quality', fontsize=12, fontweight='bold')
                ax1.grid(True, alpha=0.3, linestyle='--')
                ax1.set_ylim(0, 1)

                # Second-level plot (right)
                second_results = video_result['per_second_results']
                second_timestamps = [r['timestamp'] for r in second_results]
                second_scores = [r['quality_score'] for r in second_results]

                ax2.plot(second_timestamps, second_scores,
                         linewidth=2.5, color='#A23B72', marker='s',
                         markersize=6, markerfacecolor='white',
                         markeredgewidth=1.5, label='Per-second average')

                # Band showing variability within each second
                shade_std(ax2, second_timestamps, second_scores, second_results, '#A23B72')

                ax2.set_xlabel('Time (seconds)', fontsize=11, fontweight='bold')
                ax2.set_ylabel('Quality Score', fontsize=11, fontweight='bold')
                ax2.set_title('Per-Second Averaged Quality', fontsize=12, fontweight='bold')
                ax2.grid(True, alpha=0.3, linestyle='--')
                ax2.set_ylim(0, 1)

                # Add overall average line to both
                for axis in (ax1, ax2):
                    axis.axhline(y=avg_score, color='#F18F01', linestyle='--',
                                 linewidth=1.5, alpha=0.7,
                                 label=f'Overall avg: {avg_score:.2f}')
                    axis.legend(loc='best', framealpha=0.9)

                fig.suptitle(
                    f"Video Quality Analysis: {video_result['video_name']}, {self.task}-oriented MIQA",
                    fontsize=14, fontweight='bold', y=1.02)

                stats_text = (
                    f"Duration: {video_result['metadata']['duration']:.1f}s\n"
                    f"Frames: {video_result['num_frames_analyzed']} | "
                    f"Seconds: {video_result['num_seconds']}\n"
                    f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n"
                    f"Std Dev: {video_result['std_quality']:.2f}"
                )
                # Add to the right subplot
                ax2.text(0.02, 0.98, stats_text,
                         transform=ax2.transAxes, fontsize=9,
                         verticalalignment='top',
                         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

                suffix = 'comparison'
            else:
                # Single plot based on selected granularity
                ax = fig.add_subplot(111)
                second_results = None

                if granularity == 'frame':
                    frame_results = video_result['frame_results']
                    timestamps = [r['timestamp'] for r in frame_results]
                    scores = [r['quality_score'] for r in frame_results]
                    plot_color = '#2E86AB'
                    plot_label = 'Frame-level quality'
                    title_suffix = '(Frame-Level)'
                    suffix = 'frame'
                    marker_size = 4
                else:  # second
                    second_results = video_result['per_second_results']
                    timestamps = [r['timestamp'] for r in second_results]
                    scores = [r['quality_score'] for r in second_results]
                    plot_color = '#A23B72'
                    plot_label = 'Per-second average'
                    title_suffix = '(Per-Second Average)'
                    suffix = 'second'
                    marker_size = 6

                # Main quality plot
                ax.plot(timestamps, scores, linewidth=2, color=plot_color,
                        marker='o', markersize=marker_size,
                        markerfacecolor='white', markeredgewidth=1.5,
                        label=plot_label)

                # Add shaded region for second-level showing variability
                if granularity == 'second':
                    shade_std(ax, timestamps, scores, second_results, plot_color)

                # Add average line
                ax.axhline(y=avg_score, color='#F18F01', linestyle='--',
                           linewidth=1.5, label=f'Average: {avg_score:.2f}')

                # Styling
                ax.set_xlabel('Time (seconds)', fontsize=12, fontweight='bold')
                ax.set_ylabel('Quality Score', fontsize=12, fontweight='bold')
                ax.set_title(
                    f"Video Quality Analysis: {video_result['video_name']} {title_suffix}",
                    fontsize=14, fontweight='bold', pad=20)
                ax.grid(True, alpha=0.3, linestyle='--')
                ax.legend(loc='best', framealpha=0.9)

                # Add statistics box
                stats_text = (
                    f"Duration: {video_result['metadata']['duration']:.1f}s\n"
                    f"Frames Analyzed: {video_result['num_frames_analyzed']}\n"
                    f"Unique Seconds: {video_result['num_seconds']}\n"
                    f"Score Range: [{video_result['min_quality']:.2f}, {video_result['max_quality']:.2f}]\n"
                    f"Std Dev: {video_result['std_quality']:.2f}"
                )
                ax.text(0.02, 0.98, stats_text,
                        transform=ax.transAxes, fontsize=12,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.6))

            fig.tight_layout()

            # Save figure
            output_file = output_path / (
                f"{Path(video_result['video_name']).stem}_{self.metric_type}_quality_{suffix}.png"
            )
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        self.logger.info(f" Saved visualization: {output_file.name}")

    def visualize_results(self, results: List[Dict],
                          output_dir: str = 'inference_results',
                          video_granularity: str = 'second') -> None:
        """
        Create visualizations for all results (images and videos).

        Videos get a time-series plot; images are saved as
        ``annotated_<name>`` with a colored score box drawn in the top-left
        corner (red at 0, through yellow at 0.5, to green at 1).

        Args:
            results: List of prediction results
            output_dir: Directory to save visualizations
            video_granularity: For videos - 'frame', 'second', or 'both'
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        self.logger.info(f"\n🎨 Creating visualizations (granularity: {video_granularity})...")

        for result in results:
            if result['type'] == 'video':
                self.visualize_video_results(result, output_dir, video_granularity)
            elif result['type'] == 'image' and result.get('quality_score') is not None:
                img = result['original_image'].copy()
                draw = ImageDraw.Draw(img)

                score = result['quality_score']
                score_text = f"Quality: {score:.3f}"

                # Clamp to [0, 1] so the color channels stay within 0..255
                # even if the model emits a score outside that range.
                norm_score = min(1.0, max(0.0, score))
                if norm_score < 0.5:
                    r, g, b = 255, int(255 * norm_score * 2), 0
                else:
                    r, g, b = int(255 * (2 - norm_score * 2)), 255, 0
                color = (r, g, b)

                box_coords = [10, 10, 260, 60]
                draw.rectangle(box_coords, fill=color)

                try:
                    font = ImageFont.truetype("arial.ttf", 24)
                except OSError:
                    # Font file not available on this system
                    font = ImageFont.load_default()

                draw.text((15, 20), score_text, fill='black', font=font)

                output_file = output_path / f"annotated_{result['image_name']}"
                img.save(output_file)

        self.logger.info(f"✓ Visualizations saved to: {output_dir}/")

    def save_results(self, results: List[Dict], output_path: str = 'predictions.json') -> None:
        """Save prediction results to a JSON file.

        Non-serializable payloads (the PIL image of image results and the raw
        frame arrays of video results) are dropped before writing.
        """
        # Clean results for JSON serialization
        clean_results = []
        for r in results:
            clean_r = {k: v for k, v in r.items() if k not in ['original_image', 'frame']}

            # For video results, remove frame data but keep scores
            if clean_r.get('type') == 'video' and 'frame_results' in clean_r:
                clean_r['frame_results'] = [
                    {k: v for k, v in fr.items() if k != 'frame'}
                    for fr in clean_r['frame_results']
                ]

            clean_results.append(clean_r)

        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)

        payload = {
            'metadata': {
                'task': self.task,
                'model': self.model_name,
                'metric_type': self.metric_type,
                'timestamp': datetime.now().isoformat(),
                'total_files': len(clean_results)
            },
            'predictions': clean_results
        }
        with open(destination, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)

        self.logger.info(f"💾 Results saved to: {destination}")

    def print_summary(self, results: List[Dict]) -> None:
        """Print formatted summary of prediction results."""
        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREDICTION SUMMARY")
        self.logger.info("=" * 80)

        image_results = [r for r in results if r.get('type') == 'image']
        video_results = [r for r in results if r.get('type') == 'video']

        if image_results:
            valid_images = [r for r in image_results if r.get('quality_score') is not None]
            if valid_images:
                scores = [r['quality_score'] for r in valid_images]
                self.logger.info(f"\n📸 Image Analysis ({len(valid_images)} images)")
                self.logger.info(f" Average quality: {np.mean(scores):.2f}")
                self.logger.info(f" Score range: [{np.min(scores):.2f}, {np.max(scores):.2f}]")

        if video_results:
            self.logger.info(f"\n🎬 Video Analysis ({len(video_results)} videos)")
            for vr in video_results:
                self.logger.info(f"\n {vr['video_name']}:")
                self.logger.info(f" Duration: {vr['metadata']['duration']:.1f}s")
                self.logger.info(f" Frames analyzed: {vr['num_frames_analyzed']}")
                self.logger.info(f" Unique seconds: {vr['num_seconds']}")
                self.logger.info(f" Average quality (per-second): {vr['average_quality']:.2f}")
                self.logger.info(f" Quality range: [{vr['min_quality']:.2f}, {vr['max_quality']:.2f}]")
                self.logger.info(f" Variability (std): {vr['std_quality']:.2f}")

        self.logger.info("\n" + "=" * 80 + "\n")


def main():
    """Command-line interface for MIQA inference.

    Exits with status 1 (after printing the error and traceback to stderr)
    if inference, saving, or visualization fails.
    """
    parser = argparse.ArgumentParser(
        description='MIQA: Machine-centric Image and Video Quality Assessment',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Analyze video with per-second visualization
  python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity second

  # Analyze video with both frame and second visualizations
  python video_analytics_inference.py --input video.mp4 --task cls --visualize --viz-granularity both

  # Process directory with per-second visualization
  python video_analytics_inference.py --input ./assets/demo_video --task det --video-frames 120 --visualize --viz-granularity second --metric-type consistency
        """
    )

    parser.add_argument('--input', type=str, required=True,
                        help='Path to input image/video or directory')
    parser.add_argument('--task', type=str, required=True,
                        choices=['cls', 'det', 'ins'],
                        help='Task type')
    parser.add_argument('--model', type=str, default='ra_miqa',
                        choices=['ra_miqa'],
                        help='Model architecture (RA-MIQA only; matches Hub registry)')
    parser.add_argument('--metric-type', type=str, default='composite',
                        choices=['composite', 'consistency', 'accuracy'],
                        help='Training metric type')
    parser.add_argument('--device', type=str, default=None,
                        choices=['cuda', 'cpu'],
                        help='Device to run on')
    parser.add_argument('--video-frames', type=int, default=50,
                        help='Target number of frames to sample from videos')
    parser.add_argument('--video-sampling', type=str, default='uniform',
                        choices=['uniform', 'fps'],
                        help='Frame sampling strategy for videos')
    parser.add_argument('--video-fps', type=float, default=None,
                        help='Sampling rate in frames per second '
                             '(required with --video-sampling fps)')
    parser.add_argument('--save-results', action='store_true',
                        help='Save prediction results to file')
    parser.add_argument('--output-file', type=str, default='predictions.json',
                        help='Output file path')
    parser.add_argument('--visualize', action='store_true',
                        help='Create visualizations')
    parser.add_argument('--viz-dir', type=str, default='inference_results',
                        help='Directory for visualizations')
    parser.add_argument('--viz-granularity', type=str, default='second',
                        choices=['frame', 'second', 'both'],
                        help='Visualization granularity for videos: frame-level, per-second, or both')
    parser.add_argument('--no-progress', action='store_true',
                        help='Disable progress bar')

    args = parser.parse_args()

    # Reject an unusable sampling configuration before loading the model.
    if args.video_sampling == 'fps' and (args.video_fps is None or args.video_fps <= 0):
        parser.error('--video-sampling fps requires a positive --video-fps')

    try:
        miqa = MIQAInference(
            task=args.task,
            model_name=args.model,
            metric_type=args.metric_type,
            device=args.device,
            video_sampling=args.video_sampling,
            video_target_frames=args.video_frames,
            video_fps_sample=args.video_fps
        )

        results = miqa.predict(args.input, show_progress=not args.no_progress)
        miqa.print_summary(results)

        if args.save_results:
            miqa.save_results(results, args.output_file)

        if args.visualize:
            miqa.visualize_results(results, args.viz_dir,
                                   video_granularity=args.viz_granularity)

    except Exception as e:
        # Top-level boundary: report and exit non-zero.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    main()