miqa
xiaoqi-wang commited on
Commit
aeb2574
·
verified ·
1 Parent(s): a7301d6

Upload video_annotator_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. video_annotator_inference.py +464 -0
video_annotator_inference.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import argparse
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional, Tuple
8
+ from collections import OrderedDict
9
+ import json
10
+ from datetime import datetime
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import cv2 # OpenCV for video processing
14
+ import matplotlib.pyplot as plt # Matplotlib for plotting
15
+ import io
16
+
17
+ # Image processing imports
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ import torchvision.transforms as transforms
20
+
21
+ # Import your existing model components
22
+ # Ensure these files (models/, utils/) are in the same directory or accessible in PYTHONPATH
23
+ from models.MIQA_base import get_torch_model, get_timm_model
24
+ from models.RA_MIQA import RegionVisionTransformer
25
+ from models.hf_model_registry import HF_REPO_ID, HF_REVISION, MODEL_FILENAMES
26
+ from utils.hf_download_utils import ensure_checkpoint_from_hf
27
+ SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv'}
28
+
29
+
30
+ class MIQAInference:
31
+ """
32
+ MODIFIED Inference wrapper for MIQA models.
33
+ Now includes a method to predict on PIL Image objects directly.
34
+ """
35
+
36
+ def __init__(self, task: str, model_name: str = 'ra_miqa',
37
+ metric_type: str = 'composite', device: Optional[str] = None):
38
+ self.task = task.lower()
39
+ self.model_name = model_name
40
+ self.metric_type = metric_type
41
+ self.logger = self._setup_logger()
42
+
43
+ if device is None:
44
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+ else:
46
+ self.device = torch.device(device)
47
+
48
+ self.logger.info(f"🚀 Initializing MIQA Inference System")
49
+ self.logger.info(f" Task: {self.task.upper()}")
50
+ self.logger.info(f" Model: {self.model_name}")
51
+ self.logger.info(f" Metric Type: {self.metric_type}")
52
+ self.logger.info(f" Device: {self.device}")
53
+
54
+ self._validate_config()
55
+ self.model = self._load_model()
56
+ self.transforms1, self.transforms2 = self._get_transforms()
57
+ self.logger.info("✅ System ready for inference\n")
58
+
59
+
60
+ def _setup_logger(self) -> logging.Logger:
61
+ """Configure logging with both file and console output."""
62
+ logger = logging.getLogger('MIQA_Inference')
63
+ logger.setLevel(logging.INFO)
64
+
65
+ if logger.hasHandlers():
66
+ return logger
67
+
68
+ logger.propagate = False
69
+
70
+ # Console handler with clean formatting
71
+ console_handler = logging.StreamHandler(sys.stdout)
72
+ console_handler.setLevel(logging.INFO)
73
+ console_formatter = logging.Formatter('%(message)s')
74
+ console_handler.setFormatter(console_formatter)
75
+ logger.addHandler(console_handler)
76
+
77
+ return logger
78
+
79
+ def _validate_config(self) -> None:
80
+ """Validate that the requested configuration is supported."""
81
+
82
+ if self.metric_type not in ['composite', 'consistency', 'accuracy']:
83
+ raise ValueError(
84
+ f"Invalid metric_type '{self.metric_type}'. "
85
+ f"Supported: ['composite', 'consistency', 'accuracy']"
86
+ )
87
+
88
+ if self.task not in MODEL_FILENAMES[self.metric_type]:
89
+ raise ValueError(
90
+ f"Invalid task '{self.task}'. "
91
+ f"Supported tasks: {list(MODEL_FILENAMES[self.metric_type].keys())}"
92
+ )
93
+
94
+ if self.model_name not in MODEL_FILENAMES[self.metric_type][self.task]:
95
+ available = list(MODEL_FILENAMES[self.metric_type][self.task].keys())
96
+ raise ValueError(
97
+ f"Model '{self.model_name}' not available for task '{self.task}'. "
98
+ f"Available models: {available}"
99
+ )
100
+
101
+ def _get_checkpoint_path(self) -> str:
102
+ """Generate the path where model checkpoint should be stored."""
103
+ base_dir = Path('models') / 'checkpoints' / f'{self.metric_type}_metric'
104
+ base_dir.mkdir(parents=True, exist_ok=True)
105
+
106
+ filename = MODEL_FILENAMES[self.metric_type][self.task][self.model_name]
107
+ return str(base_dir / filename)
108
+
109
+ def _download_weights(self, checkpoint_path: str) -> bool:
110
+ """
111
+ Download model weights if not present locally.
112
+
113
+ Returns:
114
+ True if weights are available (already existed or successfully downloaded)
115
+ """
116
+ if os.path.exists(checkpoint_path):
117
+ self.logger.info(f"✓ Found cached model weights")
118
+ return True
119
+
120
+ self.logger.info(
121
+ f"⏬ Downloading from Hugging Face: repo={HF_REPO_ID}, "
122
+ f"file={Path(checkpoint_path).name}, rev={HF_REVISION}"
123
+ )
124
+ try:
125
+ ensure_checkpoint_from_hf(
126
+ repo_id=HF_REPO_ID,
127
+ filename=Path(checkpoint_path).name,
128
+ local_dir=str(Path(checkpoint_path).parent),
129
+ revision=HF_REVISION,
130
+ )
131
+ self.logger.info("✓ Successfully downloaded model weights")
132
+ return True
133
+ except Exception as e:
134
+ self.logger.error(f"❌ Failed to download model weights from Hugging Face: {e}")
135
+ return False
136
+
137
+ def _create_model(self) -> torch.nn.Module:
138
+ """Create the model architecture."""
139
+ if self.model_name == 'ra_miqa':
140
+ self.logger.info("Building Region-Aware Vision Transformer...")
141
+ model = RegionVisionTransformer(
142
+ base_model_name='vit_small_patch16_224',
143
+ pretrained=False, # We'll load our trained weights
144
+ mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
145
+ checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth'
146
+ )
147
+ else:
148
+ try:
149
+ self.logger.info(f"Building {self.model_name} from PyTorch...")
150
+ model = get_torch_model(model_name=self.model_name, pretrained=False, num_classes=1)
151
+ except Exception:
152
+ self.logger.info(f"Building {self.model_name} from timm library...")
153
+ model = get_timm_model(model_name=self.model_name, pretrained=False, num_classes=1)
154
+
155
+ return model
156
+
157
+ def _load_model(self) -> torch.nn.Module:
158
+ """Load model with pre-trained weights."""
159
+ checkpoint_path = self._get_checkpoint_path()
160
+
161
+ # Ensure weights are available
162
+ if not self._download_weights(checkpoint_path):
163
+ raise RuntimeError("Cannot proceed without model weights")
164
+
165
+ # Create model architecture
166
+ self.logger.info("🔧 Loading model...")
167
+ model = self._create_model()
168
+
169
+ # Load weights
170
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
171
+ state_dict = checkpoint.get('state_dict', checkpoint)
172
+
173
+ # Remove 'module.' prefix if present (from DataParallel training)
174
+ new_state_dict = OrderedDict()
175
+ for k, v in state_dict.items():
176
+ name = k.replace('module.', '') if k.startswith('module.') else k
177
+ new_state_dict[name] = v
178
+
179
+ model.load_state_dict(new_state_dict, strict=True)
180
+ model = model.to(self.device)
181
+ model.eval() # Set to evaluation mode
182
+
183
+ self.logger.info("✓ Model loaded successfully")
184
+
185
+ return model
186
+
187
+ def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose | None]:
188
+ """
189
+ Return preprocessing transforms based on model type.
190
+ """
191
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
192
+ IMAGENET_STD = (0.229, 0.224, 0.225)
193
+ SIMPLE_MEAN = (0.5, 0.5, 0.5)
194
+ SIMPLE_STD = (0.5, 0.5, 0.5)
195
+
196
+ # Default (for single-input backbones)
197
+ transform_imagenet = transforms.Compose([
198
+ transforms.Resize(288),
199
+ transforms.CenterCrop(size=224),
200
+ transforms.ToTensor(),
201
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
202
+ ])
203
+
204
+ transform_simple = transforms.Compose([
205
+ transforms.Resize(288),
206
+ transforms.CenterCrop(size=224),
207
+ transforms.ToTensor(),
208
+ transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD)
209
+ ])
210
+
211
+ # 1️⃣ CNNs(ResNet / EfficientNet)
212
+ if any(k in self.model_name for k in ['resnet', 'efficientnet']):
213
+ return transform_imagenet, None
214
+
215
+ # 2️⃣ ViT
216
+ elif 'vit' in self.model_name:
217
+ return transform_simple, None
218
+
219
+ # 3️⃣ ra_miqa
220
+ elif 'ra_miqa' in self.model_name:
221
+ transform_1 = transforms.Compose([
222
+ transforms.Resize(288),
223
+ transforms.CenterCrop(size=224),
224
+ transforms.ToTensor(),
225
+ transforms.Normalize(mean=SIMPLE_MEAN, std=SIMPLE_STD)
226
+ ])
227
+ transform_2 = transforms.Compose([
228
+ transforms.Resize(288),
229
+ transforms.CenterCrop((288, 288)),
230
+ transforms.ToTensor(),
231
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
232
+ ])
233
+ return transform_1, transform_2
234
+
235
+ # fallback
236
+ else:
237
+ print(f"[Warning] Unknown model type '{self.model_name}', using ImageNet normalization.")
238
+ return transform_imagenet, None
239
+
240
+ @torch.no_grad()
241
+ def predict_image_object(self, image: Image.Image) -> float:
242
+ """
243
+ NEW METHOD: Run inference on a PIL Image object.
244
+ """
245
+ # Preprocess the image
246
+ img1 = self.transforms1(image).unsqueeze(0).to(self.device)
247
+ img2 = self.transforms2(image).unsqueeze(0).to(self.device) if self.transforms2 else None
248
+
249
+ # Run inference based on model input requirements
250
+ if img2 is None:
251
+ output = self.model(img1)
252
+ else:
253
+ output = self.model(img1, img2)
254
+
255
+ score = output.item() if torch.is_tensor(output) else float(output)
256
+ return score
257
+
258
+
259
+ class VideoMIQAProcessor:
260
+ """
261
+ A wrapper to process videos using the MIQAInference engine and create
262
+ a visualized output video with scores and plots.
263
+ """
264
+ # --- Visualization Constants ---
265
+ PANEL_WIDTH = 480
266
+ FONT = cv2.FONT_HERSHEY_SIMPLEX
267
+ FONT_SCALE_L = 1.0
268
+ FONT_SCALE_M = 0.8
269
+ FONT_COLOR = (255, 255, 255) # White
270
+ LINE_THICKNESS = 2
271
+
272
+ # Plotting style
273
+ plt.style.use('dark_background')
274
+
275
+ def __init__(self, miqa_engine: MIQAInference):
276
+ self.miqa_engine = miqa_engine
277
+ self.logger = miqa_engine.logger
278
+
279
+ def _create_score_plot(self, scores: List[float], width: int, height: int) -> np.ndarray:
280
+ """
281
+ Creates a line chart of scores using Matplotlib and returns it as an OpenCV image.
282
+ """
283
+ fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
284
+ ax.plot(scores, color='#4287f5', linewidth=2)
285
+ ax.set_xlim(0, max(1, len(scores)))
286
+ ax.set_ylim(0, 1)
287
+ ax.set_title("Quality Score Fluctuation", fontsize=10)
288
+ ax.set_xlabel("Frame", fontsize=8)
289
+ ax.set_ylabel("Score", fontsize=8)
290
+ ax.grid(True, alpha=0.3)
291
+ fig.tight_layout(pad=1.5)
292
+
293
+ # Render plot to an in-memory buffer
294
+ buf = io.BytesIO()
295
+ fig.savefig(buf, format='png')
296
+ buf.seek(0)
297
+ plt.close(fig)
298
+
299
+ # Convert buffer to a PIL Image and then to an OpenCV image
300
+ plot_img_pil = Image.open(buf)
301
+ plot_img_np = np.array(plot_img_pil)
302
+ plot_img_bgr = cv2.cvtColor(plot_img_np, cv2.COLOR_RGBA2BGR)
303
+
304
+ return plot_img_bgr
305
+
306
+ def process_video(self, input_path: str, output_path: str):
307
+ """
308
+ Reads a video, analyzes each frame for quality, and writes an annotated output video.
309
+ """
310
+ self.logger.info(f"📹 Starting processing for: {Path(input_path).name}")
311
+ cap = cv2.VideoCapture(input_path)
312
+ if not cap.isOpened():
313
+ self.logger.error(f"❌ Failed to open video: {input_path}")
314
+ return
315
+
316
+ # Video properties
317
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
318
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
319
+ orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
320
+ orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
321
+
322
+ # New dimensions for output video (with side panel)
323
+ output_width = orig_width + self.PANEL_WIDTH
324
+ output_height = orig_height
325
+
326
+ # Setup video writer
327
+ fourcc = cv2.VideoWriter_fourcc(*'avc1')
328
+ out = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))
329
+
330
+ scores = []
331
+ progress_bar = tqdm(range(frame_count), desc="Analyzing frames", ncols=100)
332
+
333
+ for frame_idx in progress_bar:
334
+ ret, frame = cap.read()
335
+ if not ret:
336
+ break
337
+
338
+ # --- MIQA Inference ---
339
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
340
+ pil_image = Image.fromarray(frame_rgb)
341
+ score = self.miqa_engine.predict_image_object(pil_image)
342
+ scores.append(score)
343
+
344
+ # --- Visualization Panel ---
345
+ panel = np.zeros((orig_height, self.PANEL_WIDTH, 3), dtype=np.uint8)
346
+
347
+ # 1. Task Info
348
+ task_text = f"Task: {self.miqa_engine.task.upper()}"
349
+ cv2.putText(panel, task_text, (20, 50), self.FONT, self.FONT_SCALE_M, self.FONT_COLOR, self.LINE_THICKNESS)
350
+
351
+ # 2. Current Score
352
+ score_text = f"Quality Score: {score:.3f}"
353
+ # Color coding for score text
354
+ norm_score = max(0, score)
355
+ if norm_score < 0.5:
356
+ color = (0, int(255 * (norm_score * 2)), 255) # Red -> Yellow
357
+ else:
358
+ color = (0, 255, int(255 * (2 - norm_score * 2))) # Yellow -> Green
359
+ cv2.putText(panel, score_text, (20, 110), self.FONT, self.FONT_SCALE_L, color, self.LINE_THICKNESS + 1)
360
+
361
+ # 3. Frame Info
362
+ frame_text = f"Frame: {frame_idx + 1}/{frame_count}"
363
+ cv2.putText(panel, frame_text, (20, orig_height - 30), self.FONT, self.FONT_SCALE_M, self.FONT_COLOR, 1)
364
+
365
+ # 4. Score Plot
366
+ if len(scores) > 1:
367
+ plot_height = 300
368
+ plot_width = self.PANEL_WIDTH - 40 # with margins
369
+ plot_img = self._create_score_plot(scores, plot_width, plot_height)
370
+
371
+ # Position the plot on the panel
372
+ y_offset = 160
373
+ panel[y_offset:y_offset + plot_img.shape[0], 20:20 + plot_img.shape[1]] = plot_img
374
+
375
+ # --- Combine and Write Frame ---
376
+ combined_frame = np.concatenate((frame, panel), axis=1)
377
+ out.write(combined_frame)
378
+
379
+ # Release resources
380
+ cap.release()
381
+ out.release()
382
+ self.logger.info(f"✅ Finished processing. Annotated video saved to: {output_path}\n")
383
+
384
+
385
+ def main():
386
+ """Command-line interface for Video MIQA inference."""
387
+ parser = argparse.ArgumentParser(
388
+ description='MIQA for Video: Machine-centric Image Quality Assessment on Video Frames',
389
+ formatter_class=argparse.RawDescriptionHelpFormatter,
390
+ epilog="""
391
+ Examples:
392
+ # Analyze a single video and save the annotated output
393
+ python video_annotator_inference.py --input my_video.mp4 --task cls --model ra_miqa
394
+
395
+ # Analyze all videos in a directory
396
+ python video_annotator_inference.py --input ./video_folder/ --task det --model resnet50
397
+ """
398
+ )
399
+
400
+ parser.add_argument('--input', type=str, required=True,
401
+ help='Path to input video file or a directory containing videos.')
402
+ parser.add_argument('--task', type=str, required=True,
403
+ choices=['cls', 'det', 'ins'],
404
+ help='Task type: cls (classification), det (detection), ins (instance).')
405
+ parser.add_argument('--model', type=str, default='ra_miqa',
406
+ choices=['ra_miqa'],
407
+ help='Model architecture (default: ra_miqa; Hub weights are RA-MIQA only).')
408
+ parser.add_argument('--metric-type', type=str, default='composite',
409
+ choices=['composite', 'consistency', 'accuracy'],
410
+ help='Training metric type (default: composite).')
411
+ parser.add_argument('--device', type=str, default=None,
412
+ choices=['cuda', 'cpu'],
413
+ help='Device to run on (auto-detect if not specified).')
414
+ parser.add_argument('--output-dir', type=str, default='inference_results',
415
+ help='Directory to save the output annotated videos.')
416
+
417
+ args = parser.parse_args()
418
+
419
+ try:
420
+ # Initialize the core inference engine
421
+ miqa_engine = MIQAInference(
422
+ task=args.task,
423
+ model_name=args.model,
424
+ metric_type=args.metric_type,
425
+ device=args.device
426
+ )
427
+
428
+ # Initialize the video processor
429
+ video_processor = VideoMIQAProcessor(miqa_engine)
430
+
431
+ # Find videos to process
432
+ input_path = Path(args.input)
433
+ videos_to_process = []
434
+ if input_path.is_dir():
435
+ for ext in SUPPORTED_VIDEO_EXTENSIONS:
436
+ videos_to_process.extend(input_path.glob(f"*{ext}"))
437
+ elif input_path.is_file() and input_path.suffix.lower() in SUPPORTED_VIDEO_EXTENSIONS:
438
+ videos_to_process.append(input_path)
439
+
440
+ if not videos_to_process:
441
+ raise FileNotFoundError(f"No supported video files found in '{args.input}'")
442
+
443
+
444
+ # Create output directory
445
+ output_dir = Path(args.output_dir) / 'video' /args.task / args.metric_type
446
+ output_dir.mkdir(parents=True, exist_ok=True)
447
+
448
+ # Process each video
449
+ for video_path in videos_to_process:
450
+ output_filename = f"{video_path.stem}_miqa_{args.model}_{args.task}.mp4"
451
+ output_filepath = str(output_dir / output_filename)
452
+ video_processor.process_video(str(video_path), output_filepath)
453
+
454
+ except Exception as e:
455
+ # Use the logger if it exists, otherwise print
456
+ try:
457
+ miqa_engine.logger.error(f"\n❌ An error occurred: {str(e)}")
458
+ except:
459
+ print(f"\n❌ An error occurred: {str(e)}", file=sys.stderr)
460
+ sys.exit(1)
461
+
462
+
463
+ if __name__ == '__main__':
464
+ main()