| import os |
| import cv2 |
| import torch |
| import spaces |
| import imageio |
| import numpy as np |
| import gradio as gr |
def _skip_jit_script(f):
    """Stand-in for ``torch.jit.script`` that hands back its argument untouched."""
    return f


# Replace TorchScript compilation with a pass-through before the project
# modules below are imported.
# NOTE(review): presumably those modules apply ``torch.jit.script`` at import
# time and scripting fails or is unwanted in this environment -- confirm.
torch.jit.script = _skip_jit_script
|
|
| import argparse |
| from utils.batch_inference import ( |
| BSRInferenceLoop, BIDInferenceLoop |
| ) |
|
|
| |
| |
|
|
# Name of the torch device to use: 'cuda' when a CUDA device is reported as
# available at import time, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_example(task):
    """Return the example rows for the given task.

    Args:
        task: ``"dn"`` for the plain example clips or ``"sr"`` for their
            ``_sr`` counterparts.

    Returns:
        A new list of one-element lists, each holding the path of one example
        video under ``examples/`` (the row format passed to ``gr.Examples``).

    Raises:
        KeyError: If ``task`` is neither ``"dn"`` nor ``"sr"``.
    """
    # File-name suffix that distinguishes the two example sets.
    suffix_for_task = {"dn": "", "sr": "_sr"}
    suffix = suffix_for_task[task]

    clip_names = (
        "bus",
        "koala",
        "flamingo",
        "rhino",
        "elephant",
        "sheep",
        "dog-agility",
    )
    return [[f"examples/{clip}{suffix}.mp4"] for clip in clip_names]
|
|
|
|
|
|
def update_prompt(input_video):
    """Pass the video's file name to ``set_default_prompt`` and return its result.

    Only the last ``/``-separated component of ``input_video`` is forwarded.

    NOTE(review): ``set_default_prompt`` is neither defined nor imported in
    the visible part of this file, and nothing visible calls this function --
    confirm the helper exists before relying on it.
    """
    file_name = input_video.rsplit('/', 1)[-1]
    return set_default_prompt(file_name)
|
|
|
|
| |
# Base names of the clips that have a frame directory under examples_frames/.
_EXAMPLE_FRAME_CLIPS = (
    'bus',
    'koala',
    'dog-gooses',
    'flamingo',
    'rhino',
    'elephant',
    'sheep',
    'dog-agility',
)

# Maps a video file name to a one-element list holding its frame directory:
# first the plain clips, then their "_sr" variants.
video_to_image = {
    f'{clip}{suffix}.mp4': [f'examples_frames/{clip}{suffix}']
    for suffix in ('', '_sr')
    for clip in _EXAMPLE_FRAME_CLIPS
}

# This one directory is spelled with an underscore ("dog_gooses") although the
# video name uses a hyphen; it is kept exactly as in the original table.
# NOTE(review): confirm which spelling matches the directory on disk.
video_to_image['dog-gooses_sr.mp4'] = ['examples_frames/dog_gooses_sr']
|
|
|
|
def images_to_video(image_list, output_path, fps=10, max_frames=20):
    """Encode a sequence of images into a video file.

    Args:
        image_list: Iterable of images; each one is converted with
            ``np.array(img).astype(np.uint8)`` before being written.
        output_path: Destination path handed to ``imageio.get_writer``.
        fps: Frame rate of the written video.
        max_frames: Only the first ``max_frames`` images are written
            (default 20, the cap this function has always applied). Pass
            ``None`` to write every image.

    Returns:
        None. The video is written to ``output_path``.
    """
    # Apply the cap before converting so dropped images are never converted.
    selected = list(image_list)[:max_frames]

    # Convert everything up front so a bad image fails before the output file
    # is opened.
    frames = [np.array(img).astype(np.uint8) for img in selected]

    # The context manager closes the writer even if appending a frame raises.
    with imageio.get_writer(output_path, fps=fps, codec='libx264') as writer:
        for frame in frames:
            writer.append_data(frame)
|
|
def video2frames(video_path):
    """Write every frame of a video to numbered JPEG files.

    Frames are saved as ``00000.jpg``, ``00001.jpg``, ... inside a directory
    named after the video path with its extension removed (``clip.mp4`` ->
    ``clip/``). The directory is created if it does not exist.

    Args:
        video_path: Path of the video file to read.

    Returns:
        The path of the directory the frames were written to.

    Raises:
        OSError: If OpenCV cannot open ``video_path``.
    """
    # Strip whatever extension is present instead of assuming it is exactly
    # four characters long (".webm", ".mpeg", ...).
    root, ext = os.path.splitext(video_path)
    # Without an extension the root equals the file itself, so a directory of
    # that name could not be created; use a sibling name instead.
    img_path = root if ext else f"{video_path}_frames"

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        os.makedirs(img_path, exist_ok=True)

        frame_count = 0
        while True:
            ret, frame = video.read()
            # ``ret`` is false once no further frame can be read.
            if not ret:
                break
            cv2.imwrite(f"{img_path}/{frame_count:05}.jpg", frame)
            frame_count += 1
    finally:
        # Release the capture handle even when reading or writing fails.
        video.release()

    return img_path
|
|
@spaces.GPU(duration=120)
def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task):
    """Run the restoration pipeline on a video and return the loop's result.

    The frames of the input are looked up in ``video_to_image`` (by file name)
    or extracted with ``video2frames``; an ``argparse.Namespace`` holding the
    settings is then handed to ``BSRInferenceLoop`` (task ``"sr"``) or
    ``BIDInferenceLoop`` (task ``"dn"``).

    Args:
        input_video: Path of the input video.
        prompt: Stored as ``pos_prompt``.
        sr_ratio: Stored as ``upscale``.
        n_frames: Stored as ``n_frames``.
        n_steps: Stored as ``steps``.
        guidance_scale: Stored as ``cfg_scale``.
        seed: Stored as ``seed``.
        n_prompt: Stored as ``neg_prompt``.
        task: ``"sr"`` or ``"dn"``; selects the inference loop.

    Returns:
        Whatever the selected loop's ``run()`` returns; the UI feeds it to the
        output video component.

    Raises:
        gr.Error: If no input video was provided.
        ValueError: If ``task`` is neither ``"sr"`` nor ``"dn"``.
    """
    if not input_video:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input video before running the restoration.")

    inference_loops = {"sr": BSRInferenceLoop, "dn": BIDInferenceLoop}
    if task not in inference_loops:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(inference_loops)}")

    # Videos whose file name is listed in ``video_to_image`` use the frame
    # directory from that table; everything else is split into frames on disk.
    video_name = os.path.basename(input_video)
    if video_name in video_to_image:
        frames_path = video_to_image[video_name][0]
    else:
        frames_path = video2frames(input_video)

    print(f"[INFO] input_video: {input_video}")
    print(f"[INFO] Frames path: {frames_path}")

    args = argparse.Namespace(
        # Task selection and user-controlled settings.
        task=task,
        upscale=sr_ratio,
        steps=n_steps,
        pos_prompt=prompt,
        neg_prompt=n_prompt,
        cfg_scale=guidance_scale,
        seed=seed,
        n_frames=n_frames,
        # Fixed sampling / tiling settings.
        better_start=True,
        tiled=False,
        tile_size=512,
        tile_stride=256,
        # Input and model configuration.
        input=frames_path,
        n_samples=1,
        batch_size=10,
        final_size=(480, 854),
        config="configs/inference/my_cldm.yaml",
        # Guidance settings (the ``guidance`` flag itself is off).
        guidance=False,
        g_loss="w_mse",
        g_scale=0.0,
        g_start=1001,
        g_stop=-1,
        g_space="latent",
        g_repeat=1,
        output=" ",
        # NOTE(review): hard-coded although the module also computes a
        # ``device`` fallback; confirm whether CPU execution should be allowed.
        device="cuda",
        # Period / ratio settings consumed by the inference loop.
        warp_period=[0, 0.1],
        merge_period=[0, 0],
        ToMe_period=[0, 1],
        merge_ratio=[0.6, 0],
    )

    try:
        restored_vid_path = inference_loops[task](args).run()
    finally:
        # Free cached GPU memory whether or not inference succeeded.
        torch.cuda.empty_cache()
    return restored_vid_path
|
|
| |
| |
| |
|
|
|
|
# HTML header shown at the top of the demo page (rendered through gr.HTML).
# The heading uses font-weight 700: CSS only accepts numeric weights from 1 to
# 1000, so the previous value of 1400 was an invalid declaration.
intro = """
<div style="text-align:center">
<h1 style="font-weight: 700; text-align: center; margin-bottom: 7px;">
DiffIR2VR - <small>Zero-Shot Video Restoration</small>
</h1>
<span>[<a target="_blank" href="https://jimmycv07.github.io/DiffIR2VR_web/">Project page</a>] [<a target="_blank" href="https://huggingface.co/papers/2406.06523">arXiv</a>]</span>
<div style="display:flex; justify-content: center;margin-top: 0.5em">Note that this page is a limited demo of DiffIR2VR.
For more configurations, please visit our GitHub page. The code will be released soon!</div>
<div style="display:flex; justify-content: center;margin-top: 0.5em; color: red;">For super-resolution,
it is recommended that the final frame size (original size * upscale ratio) be around 480x854,
else the demo may fail due to lengthy inference times.</div>
</div>
"""
| |
|
|
def _build_restore_tab(tab_label, task_name):
    """Create one restoration tab and wire its button to ``DiffBIR_restore``.

    Must be called inside the ``gr.Blocks`` context so the components attach
    to the demo.

    Args:
        tab_label: Title shown on the tab.
        task_name: ``"sr"`` or ``"dn"``; forwarded to ``DiffBIR_restore``
            through a hidden textbox and used to pick the example videos.
            Only the ``"sr"`` tab shows the upscale-ratio slider.

    NOTE(review): the nesting of rows and accordion was reconstructed from a
    copy of this file whose indentation was lost -- confirm that the
    'Advanced options' accordion belongs below the button row rather than
    inside it.
    """
    with gr.Tab(label=tab_label):
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Restored Video", interactive=False)

        with gr.Row():
            run_button = gr.Button("Restore your video !", visible=True)

        with gr.Accordion('Advanced options', open=False):
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=1,
                placeholder="describe your video content",
            )
            if task_name == "sr":
                sr_ratio = gr.Slider(
                    label='Upscale ratio',
                    minimum=1,
                    maximum=16,
                    value=4,
                    step=0.5,
                )
            else:
                # Other tasks do not expose the ratio; a hidden field supplies
                # the fixed value 1.
                sr_ratio = gr.Number(value=1, visible=False)
            n_frames = gr.Slider(
                label='Frames',
                minimum=1,
                maximum=60,
                value=10,
                step=1,
            )
            n_steps = gr.Slider(
                label='Steps',
                minimum=1,
                maximum=100,
                value=5,
                step=1,
            )
            guidance_scale = gr.Slider(
                label='Guidance Scale',
                minimum=0.1,
                maximum=30.0,
                value=4.0,
                step=0.1,
            )
            seed = gr.Slider(
                label='Seed',
                minimum=-1,
                maximum=1000,
                step=1,
                randomize=True,
            )
            n_prompt = gr.Textbox(
                label='Negative Prompt',
                value="low quality, blurry, low-resolution, noisy, unsharp, weird textures",
            )
            # Hidden field that carries the task name into the callback.
            task = gr.Textbox(value=task_name, visible=False)

        # The input order must match the parameter order of DiffBIR_restore.
        run_button.click(
            fn=DiffBIR_restore,
            inputs=[
                input_video,
                prompt,
                sr_ratio,
                n_frames,
                n_steps,
                guidance_scale,
                seed,
                n_prompt,
                task,
            ],
            outputs=[output_video],
        )
        gr.Examples(
            examples=get_example(task_name),
            label='Examples',
            inputs=[input_video],
            outputs=[output_video],
            examples_per_page=7,
        )


# Assemble the page: the intro header followed by one tab per task.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    _build_restore_tab("Super-resolution with DiffBIR", "sr")
    _build_restore_tab("Denoise with DiffBIR", "dn")


demo.queue()


demo.launch()