| | import os |
| | from argparse import ArgumentParser |
| | import warnings |
| |
|
| | from omegaconf import OmegaConf |
| | import torch |
| | from torch.nn import functional as F |
| | from torch.utils.data import DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| | from torchvision.utils import make_grid |
| | from accelerate import Accelerator |
| | from accelerate.utils import set_seed |
| | from einops import rearrange |
| | from tqdm import tqdm |
| | import lpips |
| |
|
| | from model import SwinIR |
| | from utils.common import instantiate_from_config |
| |
|
| |
|
| | |
def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    Implements the ITU-R BT.601 conversion for standard-definition television. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion for details.

    The channel axis is taken to be the third-from-last dimension, so both batched
    ``(n, 3, h, w)`` and unbatched ``(3, h, w)`` inputs (or any number of leading
    dimensions) are accepted.

    Args:
        img (Tensor): Images with shape (..., 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (..., 3/1, h, w), float, on the same
            device and with the same dtype as ``img``.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        bias = torch.tensor([16.0]).to(img)
    else:
        weight = torch.tensor(
            [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]
        ).to(img)
        bias = torch.tensor([16.0, 128.0, 128.0]).to(img)

    # Put channels last so one matmul applies the colour matrix to every pixel;
    # the bias then broadcasts over the trailing (channel) axis.
    channels_last = img.movedim(-3, -1)
    out_img = (torch.matmul(channels_last, weight) + bias).movedim(-1, -3)

    # Weights and offsets are expressed on the 8-bit scale; bring the result back to [0, 1].
    out_img = out_img / 255.
    return out_img
| |
|
| |
|
| | |
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) per image (PyTorch version).

    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Number of pixels trimmed from every edge before the
            comparison. ``0`` disables cropping.
        test_y_channel (bool): Compare only the Y channel of YCbCr. Default: False.

    Returns:
        Tensor: float64 tensor of shape (n,) holding one PSNR value per image.
            A small constant (1e-8) is added to the MSE, so identical images
            yield a finite value instead of infinity.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        # Same window for both spatial axes of both images.
        inner = slice(crop_border, -crop_border)
        img, img2 = img[:, :, inner, inner], img2[:, :, inner, inner]

    if test_y_channel:
        img, img2 = rgb2ycbcr_pt(img, y_only=True), rgb2ycbcr_pt(img2, y_only=True)

    # Accumulate the error in double precision.
    diff = img.to(torch.float64) - img2.to(torch.float64)
    mse = torch.mean(diff**2, dim=[1, 2, 3])
    return 10. * torch.log10(1. / (mse + 1e-8))
| |
|
| |
|
def _to_bchw(batch, device):
    """Rearrange a (b, h, w, c) batch into a contiguous float (b, c, h, w) tensor on ``device``."""
    return rearrange(batch, "b h w c -> b c h w").contiguous().float().to(device)


def _gather_mean(accelerator, values, device):
    """Gather a list of Python floats from every process and return their overall mean."""
    stacked = torch.tensor(values, device=device).unsqueeze(0)
    return accelerator.gather(stacked).mean().item()


def main(args) -> None:
    """Train a SwinIR model with a summed MSE loss, logging to TensorBoard.

    The run is driven entirely by the OmegaConf file at ``args.config``. Keys read here:
    ``model.swinir``, ``dataset.train``, ``dataset.val`` and, under ``train``:
    ``exp_dir``, ``resume``, ``learning_rate``, ``batch_size``, ``num_workers``,
    ``train_steps``, ``log_every``, ``ckpt_every``, ``image_every``, ``val_every``.

    Both datasets must yield 3-tuples whose first two items are the ground-truth and
    low-quality images in (b, h, w, c) layout; the third item is ignored.

    Args:
        args: Parsed command-line namespace; only ``args.config`` is used.

    Raises:
        ValueError: If training is requested (``train_steps > 0``) but the training
            loader yields no batches, which would otherwise loop forever.
    """
    # Setup accelerator:
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)
    is_main = accelerator.is_local_main_process

    # Setup an experiment folder (only the local main process writes to disk):
    if is_main:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # Create model:
    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    if cfg.train.resume:
        # NOTE(security): torch.load unpickles the file, so only resume from trusted
        # checkpoints (consider weights_only=True where the installed torch supports it).
        swinir.load_state_dict(torch.load(cfg.train.resume, map_location="cpu"), strict=True)
        if is_main:
            print(f"strictly load weight from checkpoint: {cfg.train.resume}")
    elif is_main:
        print("initialize from scratch")

    # Setup optimizer:
    opt = torch.optim.AdamW(
        swinir.parameters(), lr=cfg.train.learning_rate,
        weight_decay=0
    )

    # Setup data:
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    val_dataset = instantiate_from_config(cfg.dataset.val)
    val_loader = DataLoader(
        dataset=val_dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=False, drop_last=False
    )
    if is_main:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    # Prepare models for training:
    swinir.train().to(device)
    swinir, opt, loader, val_loader = accelerator.prepare(swinir, opt, loader, val_loader)
    # Unwrapped handle so checkpoints are saved from the underlying model.
    pure_swinir = accelerator.unwrap_model(swinir)

    # Variables for monitoring/logging purposes:
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []

    # drop_last=True makes the loader empty when the dataset is smaller than one batch;
    # the outer while-loop below would then never advance global_step.
    if max_steps > 0 and len(loader) == 0:
        raise ValueError(
            "Training loader yields no batches (dataset smaller than batch_size with "
            "drop_last=True?); cannot train."
        )

    with warnings.catch_warnings():
        # Silence warnings raised while constructing the LPIPS network.
        warnings.simplefilter("ignore")
        lpips_model = lpips.LPIPS(net="alex", verbose=is_main).eval().to(device)
    if is_main:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not is_main, unit="batch", total=len(loader))
        for gt, lq, _ in loader:
            # (gt + 1) / 2 presumably maps gt from [-1, 1] to [0, 1] -- confirm against the dataset.
            gt = _to_bchw((gt + 1) / 2, device)
            lq = _to_bchw(lq, device)
            pred = swinir(lq)
            loss = F.mse_loss(input=pred, target=gt, reduction="sum")

            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()
            accelerator.wait_for_everyone()

            global_step += 1
            # Read the scalar once: every .item() call forces a host/device sync.
            loss_value = loss.item()
            step_loss.append(loss_value)
            epoch_loss.append(loss_value)
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

            # Log loss values:
            if global_step % cfg.train.log_every == 0:
                # Gather values from all processes.
                avg_loss = _gather_mean(accelerator, step_loss, device)
                step_loss.clear()
                if is_main:
                    writer.add_scalar("train/loss_step", avg_loss, global_step)

            # Save checkpoint:
            if global_step % cfg.train.ckpt_every == 0:
                if is_main:
                    checkpoint = pure_swinir.state_dict()
                    ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, ckpt_path)

            # Log example images from the current training batch:
            if global_step % cfg.train.image_every == 0 or global_step == 1:
                swinir.eval()
                num_logged = 12
                log_gt, log_lq = gt[:num_logged], lq[:num_logged]
                with torch.no_grad():
                    log_pred = swinir(log_lq)
                if is_main:
                    for tag, image in [
                        ("image/pred", log_pred),
                        ("image/gt", log_gt),
                        ("image/lq", log_lq),
                    ]:
                        writer.add_image(tag, make_grid(image, nrow=4), global_step)
                swinir.train()

            # Evaluate model:
            if global_step % cfg.train.val_every == 0:
                swinir.eval()
                val_loss = []
                val_lpips = []
                val_psnr = []
                val_pbar = tqdm(iterable=None, disable=not is_main, unit="batch",
                                total=len(val_loader), leave=False, desc="Validation")
                for val_gt, val_lq, _ in val_loader:
                    val_gt = _to_bchw((val_gt + 1) / 2, device)
                    val_lq = _to_bchw(val_lq, device)
                    # Forward pass and metrics are all evaluation-only, so no autograd graph is needed.
                    with torch.no_grad():
                        val_pred = swinir(val_lq)
                        val_loss.append(F.mse_loss(input=val_pred, target=val_gt, reduction="sum").item())
                        val_lpips.append(lpips_model(val_pred, val_gt, normalize=True).mean().item())
                        val_psnr.append(calculate_psnr_pt(val_pred, val_gt, crop_border=0).mean().item())
                    val_pbar.update(1)
                val_pbar.close()
                avg_val_loss = _gather_mean(accelerator, val_loss, device)
                avg_val_lpips = _gather_mean(accelerator, val_lpips, device)
                avg_val_psnr = _gather_mean(accelerator, val_psnr, device)
                if is_main:
                    for tag, val in [
                        ("val/loss", avg_val_loss),
                        ("val/lpips", avg_val_lpips),
                        ("val/psnr", avg_val_psnr)
                    ]:
                        writer.add_scalar(tag, val, global_step)
                swinir.train()

            accelerator.wait_for_everyone()

            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = _gather_mean(accelerator, epoch_loss, device)
        epoch_loss.clear()
        if is_main:
            writer.add_scalar("train/loss_epoch", avg_epoch_loss, global_step)

    if is_main:
        print("done!")
        writer.close()
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: the only option is the path to the training config.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    main(cli.parse_args())
| |
|