| | from typing import Tuple, Set, List, Dict |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from model import ( |
| | ControlledUnetModel, ControlNet, |
| | AutoencoderKL, FrozenOpenCLIPEmbedder |
| | ) |
| | from utils.common import sliding_windows, count_vram_usage, gaussian_weights |
| |
|
| |
|
def disabled_train(self: nn.Module, mode: bool = True) -> nn.Module:
    """Drop-in replacement for ``train`` that leaves the module untouched.

    Assign this over ``model.train`` so that later attempts to switch
    between train and eval mode have no effect.

    Args:
        self: The object the call is made on; returned unchanged.
        mode: Accepted only for signature compatibility with
            ``nn.Module.train(mode=True)``; it is ignored.

    Returns:
        ``self``, exactly as received.
    """
    return self
| |
|
| |
|
class ControlLDM(nn.Module):
    """Bundle of a UNet, a VAE, a CLIP text encoder and a ControlNet.

    The four sub-modules are built from the keyword mappings given to the
    constructor. ``forward`` runs the ControlNet on the noisy latent and the
    conditioning, scales its outputs with ``control_scales`` and feeds them
    to the UNet.
    """

    def __init__(
        self,
        unet_cfg,
        vae_cfg,
        clip_cfg,
        controlnet_cfg,
        latent_scale_factor
    ):
        """Build the sub-modules.

        Args:
            unet_cfg: Keyword arguments for ``ControlledUnetModel``.
            vae_cfg: Keyword arguments for ``AutoencoderKL``.
            clip_cfg: Keyword arguments for ``FrozenOpenCLIPEmbedder``.
            controlnet_cfg: Keyword arguments for ``ControlNet``.
            latent_scale_factor: Multiplier applied to VAE latents after
                encoding and divided out again before decoding.
        """
        super().__init__()
        self.unet = ControlledUnetModel(**unet_cfg)
        self.vae = AutoencoderKL(**vae_cfg)
        self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
        self.controlnet = ControlNet(**controlnet_cfg)
        self.scale_factor = latent_scale_factor
        # One multiplier per ControlNet output. forward() pairs them with the
        # outputs via zip(), so at most 13 outputs are ever used.
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
        """Initialise UNet, VAE and CLIP from ``sd`` and freeze them.

        Each sub-module key ``k`` is looked up in ``sd`` as
        ``"<prefix>.k"`` using the prefixes in ``module_map`` below.

        Args:
            sd: Checkpoint state dict holding the prefixed weights.

        Returns:
            The keys of ``sd`` that were not consumed by any sub-module.

        Raises:
            KeyError: If ``sd`` lacks a key required by a sub-module.
        """
        module_map = {
            "unet": "model.diffusion_model",
            "vae": "first_stage_model",
            "clip": "cond_stage_model",
        }
        modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
        used = set()
        for name, module in modules:
            prefix = module_map[name]
            init_sd = {}
            for key in module.state_dict():
                target_key = f"{prefix}.{key}"
                init_sd[key] = sd[target_key].clone()
                used.add(target_key)
            module.load_state_dict(init_sd, strict=True)
        unused = set(sd.keys()) - used

        # Freeze the pretrained parts: eval mode, no gradients, and replace
        # ``train`` so later mode switches leave them as they are.
        for module in [self.vae, self.clip, self.unet]:
            module.eval()
            module.train = disabled_train
            for p in module.parameters():
                p.requires_grad = False
        return unused

    @torch.no_grad()
    def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
        """Load ControlNet weights from ``sd`` with strict key matching."""
        self.controlnet.load_state_dict(sd, strict=True)

    @torch.no_grad()
    def load_controlnet_from_unet(self) -> Tuple[Set[str], Set[str]]:
        """Initialise the ControlNet from the UNet's current weights.

        For every key in the ControlNet state dict:

        * present in the UNet with the same size: the UNet tensor is copied;
        * present in the UNet with a different size: the UNet tensor is
          padded with zeros along dimension 1 up to the ControlNet size
          (the code assumes a 4-D tensor that differs only in dimension 1);
        * absent from the UNet: the ControlNet's own value is kept.

        Returns:
            A pair ``(init_with_new_zero, init_with_scratch)`` holding the
            keys that were zero-padded and the keys kept from the ControlNet.
        """
        unet_sd = self.unet.state_dict()
        scratch_sd = self.controlnet.state_dict()
        init_sd = {}
        init_with_new_zero = set()
        init_with_scratch = set()
        for key in scratch_sd:
            if key in unet_sd:
                this, target = scratch_sd[key], unet_sd[key]
                if this.size() == target.size():
                    init_sd[key] = target.clone()
                else:
                    d_ic = this.size(1) - target.size(1)
                    oc, _, h, w = this.size()
                    # Create the padding on the same device as the UNet
                    # weight; torch.cat requires all inputs on one device.
                    zeros = torch.zeros(
                        (oc, d_ic, h, w),
                        dtype=target.dtype,
                        device=target.device,
                    )
                    init_sd[key] = torch.cat((target, zeros), dim=1)
                    init_with_new_zero.add(key)
            else:
                init_sd[key] = scratch_sd[key].clone()
                init_with_scratch.add(key)
        self.controlnet.load_state_dict(init_sd, strict=True)
        return init_with_new_zero, init_with_scratch

    def vae_encode(self, image: torch.Tensor, sample: bool=True) -> torch.Tensor:
        """Encode ``image`` with the VAE and multiply by ``scale_factor``.

        Args:
            image: Input passed to ``self.vae.encode``.
            sample: If true, call ``sample()`` on the encoder result;
                otherwise call ``mode()``.

        Returns:
            The scaled latent.
        """
        posterior = self.vae.encode(image)
        latent = posterior.sample() if sample else posterior.mode()
        return latent * self.scale_factor

    def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True) -> torch.Tensor:
        """Encode ``image`` tile by tile and blend the overlapping latents.

        Tiles are enumerated in latent coordinates (pixel sizes divided by
        the hard-coded factor 8). Each tile's latent is accumulated with
        weights from ``gaussian_weights`` and the sum is divided by the
        accumulated weights.

        Args:
            image: 4-D input; the last two dimensions are height and width.
            tile_size: Tile edge length in pixels.
            tile_stride: Step between tiles in pixels.
            sample: Forwarded to ``vae_encode``.

        Returns:
            A float32 tensor of shape ``(bs, 4, h // 8, w // 8)`` on the
            device of ``image``.
        """
        bs, _, h, w = image.shape
        z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
        count = torch.zeros_like(z, dtype=torch.float32)
        # [None, None] adds two leading axes so the weights broadcast over
        # the batch and channel dimensions.
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
        tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
        for hi, hi_end, wi, wi_end in tiles:
            # Window bounds are in latent units; scale by 8 for pixel slices.
            tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
            z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # Normalise by the summed weights to get a weighted average.
        z.div_(count)
        return z

    def vae_decode(self, z: torch.Tensor) -> torch.Tensor:
        """Divide ``z`` by ``scale_factor`` and decode it with the VAE."""
        return self.vae.decode(z / self.scale_factor)

    @count_vram_usage
    def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor:
        """Decode ``z`` tile by tile and blend the overlapping outputs.

        Tiles are enumerated in latent coordinates; each decoded tile covers
        an area 8 times larger per side (hard-coded factor). Decoded tiles
        are accumulated with weights from ``gaussian_weights`` and the sum
        is divided by the accumulated weights.

        Args:
            z: 4-D latent; the last two dimensions are height and width.
            tile_size: Tile edge length in latent units.
            tile_stride: Step between tiles in latent units.

        Returns:
            A float32 tensor of shape ``(bs, 3, h * 8, w * 8)`` on the
            device of ``z``.
        """
        bs, _, h, w = z.shape
        image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
        count = torch.zeros_like(image, dtype=torch.float32)
        # [None, None] adds two leading axes so the weights broadcast over
        # the batch and channel dimensions.
        weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
        tiles = sliding_windows(h, w, tile_size, tile_stride)
        for hi, hi_end, wi, wi_end in tiles:
            tile_z = z[:, :, hi:hi_end, wi:wi_end]
            image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights
            count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
        # Normalise by the summed weights to get a weighted average.
        image.div_(count)
        return image

    def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
        """Build the conditioning dict consumed by ``forward``.

        Args:
            clean: Image tensor; it is mapped through ``clean * 2 - 1``
                before encoding (presumably [0, 1] -> [-1, 1]; confirm
                against the caller).
            txt: Prompts passed to ``self.clip.encode``.

        Returns:
            A dict with ``"c_txt"`` (CLIP encoding of ``txt``) and
            ``"c_img"`` (VAE latent of the rescaled image, using the
            encoder's ``mode()`` rather than a sample).
        """
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode(clean * 2 - 1, sample=False)
        )

    @count_vram_usage
    def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]:
        """Tiled variant of ``prepare_condition``.

        Identical except that the image latent is produced with
        ``vae_encode_tiled`` using ``tile_size`` and ``tile_stride``
        (both in pixels).
        """
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False)
        )

    def forward(self, x_noisy, t, cond):
        """Run the ControlNet and then the UNet.

        Args:
            x_noisy: Input passed as ``x`` to both networks.
            t: Passed as ``timesteps`` to both networks.
            cond: Mapping with keys ``"c_txt"`` and ``"c_img"``, as built by
                ``prepare_condition``.

        Returns:
            The UNet output.
        """
        c_txt = cond["c_txt"]
        c_img = cond["c_img"]
        control = self.controlnet(
            x=x_noisy, hint=c_img,
            timesteps=t, context=c_txt
        )
        # Scale each ControlNet output by its configured multiplier.
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = self.unet(
            x=x_noisy, timesteps=t,
            context=c_txt, control=control, only_mid_control=False
        )
        return eps
| |
|