from argparse import ArgumentParser, Namespace

import torch

from accelerate.utils import set_seed
from utils.inference import (
    V1InferenceLoop,
    BSRInferenceLoop, BFRInferenceLoop, BIDInferenceLoop, UnAlignedBFRInferenceLoop
)
| |
|
| |
|
def check_device(device: str) -> str:
    """Validate the requested compute device, falling back to "cpu" when unavailable.

    Args:
        device: Requested device name ("cpu", "cuda", or "mps"). Any other
            string is passed through unchanged.

    Returns:
        The device name that can actually be used ("cpu" when the requested
        CUDA/MPS backend is not available).
    """
    if device == "cuda":
        if not torch.cuda.is_available():
            print("CUDA not available because the current PyTorch install was not "
                  "built with CUDA enabled.")
            device = "cpu"
    elif device == "mps":
        if not torch.backends.mps.is_available():
            # Distinguish "PyTorch built without MPS" from "hardware/OS too old".
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not "
                      "built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ "
                      "and/or you do not have an MPS-enabled device on this machine.")
            # Both unavailable-MPS cases fall back to CPU.
            device = "cpu"
    print(f"using device {device}")
    return device
| |
|
| |
|
def parse_args() -> Namespace:
    """Build the CLI parser for the inference script and parse ``sys.argv``.

    Returns:
        The parsed argument namespace.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # Task selection and model version.
    add("--task", type=str, required=True, choices=["sr", "dn", "fr", "fr_bg"])
    add("--upscale", type=float, required=True)
    add("--version", type=str, default="v2", choices=["v1", "v2"])

    # Sampling / diffusion options.
    add("--steps", type=int, default=50)
    add("--better_start", action="store_true")
    add("--tiled", action="store_true")
    add("--tile_size", type=int, default=512)
    add("--tile_stride", type=int, default=256)
    add("--pos_prompt", type=str, default="")
    add("--neg_prompt", type=str, default="low quality, blurry, low-resolution, noisy, unsharp, weird textures")
    add("--cfg_scale", type=float, default=4.0)

    # Input selection.
    add("--input", type=str, required=True)
    add("--n_samples", type=int, default=1)

    # Guidance options.
    add("--guidance", action="store_true")
    add("--g_loss", type=str, default="w_mse", choices=["mse", "w_mse"])
    add("--g_scale", type=float, default=0.0)
    add("--g_start", type=int, default=1001)
    add("--g_stop", type=int, default=-1)
    add("--g_space", type=str, default="latent")
    add("--g_repeat", type=int, default=1)

    # Output location.
    add("--output", type=str, required=True)

    # Reproducibility and hardware.
    add("--seed", type=int, default=231)
    add("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])

    return parser.parse_args()
| |
|
| |
|
def main():
    """Entry point: parse CLI args, resolve the device and seed, then run
    the inference loop matching the chosen version and task."""
    args = parse_args()
    args.device = check_device(args.device)
    set_seed(args.seed)
    # v1 has a single loop; v2 dispatches on the task name.
    if args.version == "v1":
        loop_cls = V1InferenceLoop
    else:
        loop_cls = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
        }[args.task]
    loop_cls(args).run()
    print("done!")
| |
|
| |
|
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|