Nekochu committed on
Commit
f4a7288
·
1 Parent(s): 18d73cb

postprocess at model res, defer resize+write to CPU (saves ~35s GPU)

Browse files

GPU phase: run postprocessing (clean_matte + despill) at the model resolution (1024) instead of 4K
(~8x fewer pixels), then save the results as uint8 to /tmp (~4MB/frame). No 4K file encoding happens during GPU time.

After the GPU is released (free CPU time): load the 1024 results, resize them to the output resolution (LANCZOS4),
and write all 4 outputs (comp+fg+matte+processed).

Before: 28s inference@4K + 26s write@4K = 62s GPU
Target: 16s inference@1024 + 1s save = ~22s GPU

Files changed (1) hide show
  1. app.py +42 -27
app.py CHANGED
@@ -539,20 +539,18 @@ def corridorkey_batch_pytorch(model, images_f32, masks_f32, img_size,
539
  out = model(inp)
540
  del inp
541
 
542
- # --- GPU Postprocessing (despill + clean_matte + resize stay on device) ---
 
543
  alpha = out["alpha"].float()
544
  fg = out["fg"].float()
545
 
546
- alpha = TF.resize(alpha, [h, w])
547
- fg = TF.resize(fg, [h, w])
548
-
549
  if auto_despeckle:
550
  alpha = clean_matte_torch(alpha, area_threshold=int(despeckle_size), dilation=25, blur_size=5)
551
  fg = despill_torch(fg, despill_strength, screen_channel=screen_channel)
552
 
553
- # --- Single CPU transfer at the end ---
554
- alpha_np = alpha.cpu().numpy()
555
- fg_np = fg.cpu().numpy()
556
  del alpha, fg
557
 
558
  results = []
@@ -791,29 +789,19 @@ def _gpu_phase(video_path, resolution, despill_val, mask_mode,
791
  logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)",
792
  len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1))
793
 
794
- from concurrent.futures import ThreadPoolExecutor
795
- bg_lin = srgb_to_linear(create_checkerboard(w, h))
796
- comp_dir = os.path.join(tmpdir, "Comp")
797
- matte_dir = os.path.join(tmpdir, "Matte")
798
- fg_dir = os.path.join(tmpdir, "FG")
799
- processed_dir = os.path.join(tmpdir, "Processed")
800
- for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
801
- os.makedirs(d, exist_ok=True)
802
-
803
- t_write = time.time()
804
- progress(0.86, desc="Writing preview frames...")
805
- with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool:
806
- futs = [pool.submit(_write_frame, idx, alpha, fg, w, h, bg_lin,
807
- comp_dir, fg_dir, matte_dir, processed_dir)
808
- for idx, alpha, fg in all_results]
809
- for f in futs:
810
- f.result()
811
  del all_results
812
  gc.collect()
813
- logger.info("[GPU phase] Fast write in %.1fs", time.time() - t_write)
814
 
815
  return {
816
- "results": "written", "frame_times": frame_times,
817
  "use_gpu": True, "batch_size": batch_size,
818
  "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
819
  "screen_color": screen_color,
@@ -938,8 +926,35 @@ def process_video(video_path, resolution, despill_val, mask_mode,
938
  fg_dir = os.path.join(tmpdir, "FG")
939
  matte_dir = os.path.join(tmpdir, "Matte")
940
  processed_dir = os.path.join(tmpdir, "Processed")
 
 
941
 
942
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943
 
944
  # Phase 3: stitch videos from written frames
945
  logger.info("[Phase 3] Stitching videos")
@@ -970,7 +985,7 @@ def process_video(video_path, resolution, despill_val, mask_mode,
970
  status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | "
971
  f"{avg:.2f}s/frame | {engine}" +
972
  (f" batch={batch_size}" if use_gpu else "") +
973
- f" | {t_cpu:.0f}s CPU + {t_gpu:.0f}s GPU = {wall:.0f}s total" +
974
  (f" | {sc} screen" if sc != "green" else ""))
975
 
976
  return (
 
539
  out = model(inp)
540
  del inp
541
 
542
+ # --- GPU Postprocessing at MODEL resolution (1024/2048, NOT output 4K) ---
543
+ # Resize to output happens on CPU after GPU release (free time)
544
  alpha = out["alpha"].float()
545
  fg = out["fg"].float()
546
 
 
 
 
547
  if auto_despeckle:
548
  alpha = clean_matte_torch(alpha, area_threshold=int(despeckle_size), dilation=25, blur_size=5)
549
  fg = despill_torch(fg, despill_strength, screen_channel=screen_channel)
550
 
551
+ # Transfer at model resolution (1024×1024 = 4MB/frame, not 4K = 33MB/frame)
552
+ alpha_np = (alpha.clamp(0, 1) * 255).byte().cpu().numpy()
553
+ fg_np = (fg.clamp(0, 1) * 255).byte().cpu().numpy()
554
  del alpha, fg
555
 
556
  results = []
 
789
  logger.info("[GPU phase] done: %d frames in %.1fs (%.2fs/fr)",
790
  len(all_results), gpu_elapsed, gpu_elapsed / max(len(all_results), 1))
791
 
792
+ # Save model-resolution uint8 results to /tmp (tiny: ~4MB/frame at 1024)
793
+ raw_dir = os.path.join(tmpdir, "raw")
794
+ os.makedirs(raw_dir, exist_ok=True)
795
+ t_save = time.time()
796
+ for idx, alpha, fg in all_results:
797
+ np.save(os.path.join(raw_dir, f"alpha_{idx:05d}.npy"), alpha)
798
+ np.save(os.path.join(raw_dir, f"fg_{idx:05d}.npy"), fg)
 
 
 
 
 
 
 
 
 
 
799
  del all_results
800
  gc.collect()
801
+ logger.info("[GPU phase] Raw save in %.1fs", time.time() - t_save)
802
 
803
  return {
804
+ "results": "raw", "raw_dir": raw_dir, "frame_times": frame_times,
805
  "use_gpu": True, "batch_size": batch_size,
806
  "w": w, "h": h, "fps": fps, "tmpdir": tmpdir,
807
  "screen_color": screen_color,
 
926
  fg_dir = os.path.join(tmpdir, "FG")
927
  matte_dir = os.path.join(tmpdir, "Matte")
928
  processed_dir = os.path.join(tmpdir, "Processed")
929
+ for d in [comp_dir, fg_dir, matte_dir, processed_dir]:
930
+ os.makedirs(d, exist_ok=True)
931
 
932
  try:
933
+ # Phase 2: CPU resize + write (GPU results saved at model resolution)
934
+ raw_dir = data.get("raw_dir")
935
+ if raw_dir and use_gpu:
936
+ from concurrent.futures import ThreadPoolExecutor
937
+ t_phase2 = time.time()
938
+ bg_lin = srgb_to_linear(create_checkerboard(w, h))
939
+ n_frames = len(frame_times)
940
+ logger.info("[Phase 2] CPU resize %d→%dx%d + write (%d frames)",
941
+ int(resolution), w, h, n_frames)
942
+ progress(0.85, desc=f"Resizing to {w}x{h} + writing...")
943
+
944
+ def _resize_and_write(idx):
945
+ alpha_1k = np.load(os.path.join(raw_dir, f"alpha_{idx:05d}.npy"))
946
+ fg_1k = np.load(os.path.join(raw_dir, f"fg_{idx:05d}.npy"))
947
+ alpha = cv2.resize(alpha_1k, (w, h), interpolation=cv2.INTER_LANCZOS4)
948
+ fg = cv2.resize(fg_1k, (w, h), interpolation=cv2.INTER_LANCZOS4)
949
+ alpha = alpha.astype(np.float32) / 255.0
950
+ fg = fg.astype(np.float32) / 255.0
951
+ if alpha.ndim == 2:
952
+ alpha = alpha[:, :, np.newaxis]
953
+ _write_frame(idx, alpha, fg, w, h, bg_lin, comp_dir, fg_dir, matte_dir, processed_dir)
954
+
955
+ with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as pool:
956
+ list(pool.map(_resize_and_write, range(n_frames)))
957
+ logger.info("[Phase 2] CPU write in %.1fs", time.time() - t_phase2)
958
 
959
  # Phase 3: stitch videos from written frames
960
  logger.info("[Phase 3] Stitching videos")
 
985
  status = (f"Processed {n} frames ({w}x{h}) at {resolution}px | "
986
  f"{avg:.2f}s/frame | {engine}" +
987
  (f" batch={batch_size}" if use_gpu else "") +
988
+ f" | {t_gpu:.0f}s GPU, {wall:.0f}s total" +
989
  (f" | {sc} screen" if sc != "green" else ""))
990
 
991
  return (