| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Classifier-free guidance scaler.""" |
|
|
| import torch |
|
|
|
|
class GuidanceScaler(object):
    """Guidance scaler.

    Holds the classifier-free guidance configuration and provides helpers to
    replicate inputs into stacked guidance passes and to combine the logits
    those passes produce.

    Passes are laid out along dim 0 as ``[cond, uncond]`` or, when
    ``extra_pass`` is true, ``[cond, uncond, extra]``. The third pass is the
    image-conditional pass (``image_guidance_scale``) or the perturbed pass
    (``spatiotemporal_guidance_scale``).
    """

    def __init__(self, **kwargs):
        # Main guidance weight; every guidance pass is skipped unless it is > 1.
        self.guidance_scale = kwargs.get("guidance_scale", 1)
        # ``maybe_disable`` turns guidance off for timesteps below this (0 = never).
        self.guidance_trunc = kwargs.get("guidance_trunc", 0)
        # Lower clamp of the norm-ratio rescale in ``renorm``; >= 1 disables it.
        self.guidance_renorm = kwargs.get("guidance_renorm", 1)
        self.image_guidance_scale = kwargs.get("image_guidance_scale", 0)
        self.spatiotemporal_guidance_scale = kwargs.get("spatiotemporal_guidance_scale", 0)
        # A falsy minimum (None or 0) falls back to ``guidance_scale``, i.e. no decay range.
        self.min_guidance_scale = kwargs.get("min_guidance_scale", None) or self.guidance_scale
        self.inc_guidance_scale = self.guidance_scale - self.min_guidance_scale

    @property
    def extra_pass(self) -> bool:
        """Return if an additional (third) guidance pass is required."""
        return self.image_guidance_scale + self.spatiotemporal_guidance_scale > 0

    def clone(self):
        """Return an independent copy of the current guidance scaler.

        The complete attribute state is copied, including
        ``inc_guidance_scale``. Re-deriving that value from a decayed or
        disabled ``guidance_scale`` would silently change the copy's decay
        schedule.
        """
        scaler = GuidanceScaler(**self.__dict__)
        scaler.__dict__.update(self.__dict__)
        return scaler

    def decay_guidance_scale(self, decay=0):
        """Set ``guidance_scale`` to ``min_guidance_scale + inc_guidance_scale * decay``.

        ``decay=1`` restores the initial scale; ``decay=0`` yields
        ``min_guidance_scale``.
        """
        self.guidance_scale = self.inc_guidance_scale * decay + self.min_guidance_scale

    def expand(self, x: torch.Tensor, padding: torch.Tensor = None) -> torch.Tensor:
        """Expand input tensor for guidance passes.

        Args:
            x: Input tensor. Its leading dim is merged with the pass axis.
            padding: Optional replacement for the second (unconditional)
                copy. It is used only when image guidance is enabled and must
                broadcast to ``x``'s shape.

        Returns:
            ``x`` itself when guidance is inactive (``guidance_scale <= 1``).
            Otherwise, 2 or 3 copies of ``x`` concatenated along dim 0.
        """
        if self.guidance_scale <= 1:
            # Single pass: hand back the input untouched. Padding must never be
            # written into the caller's tensor.
            return x
        x = torch.stack([x] * (3 if self.extra_pass else 2))
        if self.image_guidance_scale and padding is not None:
            # Swap in the padding for the second (fully unconditional) pass.
            x[1] = padding
        return x.flatten(0, 1)

    def expand_text(self, c: torch.Tensor) -> torch.Tensor:
        """Expand text embedding tensor for guidance passes.

        ``c`` is split into two halves along dim 0, ``[cond, uncond]``. With an
        extra pass, a third chunk is appended: the uncond half for image
        guidance, or the cond half for spatiotemporal guidance. Without an
        extra pass, ``c`` is returned unchanged.
        """
        if not self.extra_pass:
            return c
        chunks = list(c.chunk(2))
        if self.image_guidance_scale:
            chunks.append(chunks[1])
        if self.spatiotemporal_guidance_scale:
            chunks.append(chunks[0])
        return torch.cat(chunks)

    def maybe_disable(self, timestep, *args):
        """Disable all guidance passes if matching truncation threshold.

        Truncation applies when guidance is active, ``guidance_trunc`` is set
        and ``float(timestep) < guidance_trunc``. In that case
        ``guidance_scale`` is set to 1 and a list holding the first
        (conditional) pass of each tensor in ``args`` is returned. Otherwise
        ``args`` is returned unchanged, as a tuple.
        """
        if self.guidance_scale > 1 and self.guidance_trunc:
            if float(timestep) < self.guidance_trunc:
                self.guidance_scale = 1
                return [_.chunk(3 if self.extra_pass else 2)[0] for _ in args]
        return args

    def renorm(self, x, cond):
        """Apply guidance renormalization to input logits.

        ``x`` is multiplied in place by
        ``clamp(||cond|| / ||x||, guidance_renorm, 1)``. Norms are taken per
        sample, over every dim except dim 0. When ``guidance_renorm >= 1``,
        ``x`` is returned untouched.
        """
        if self.guidance_renorm >= 1:
            return x
        args = {"dim": tuple(range(1, len(x.shape))), "keepdim": True}
        return x.mul_(cond.norm(**args).div_(x.norm(**args)).clamp(self.guidance_renorm, 1))

    def scale(self, x: torch.Tensor) -> torch.Tensor:
        """Apply guidance passes to input logits.

        ``x`` holds the stacked pass outputs produced for ``expand``. Chunks
        of ``x`` are updated in place where noted.

        - Image guidance:
          ``uncond + gs * (cond - imgcond)``, then renorm, then
          ``+ igs * (imgcond - uncond)``.
        - Spatiotemporal guidance:
          ``uncond + gs * (cond - uncond)``, then renorm, then
          ``+ sgs * (cond - perturb)``.
        - Otherwise: ``uncond + gs * (cond - uncond)``, then renorm.
        """
        if self.guidance_scale <= 1:
            return x
        if self.image_guidance_scale:
            cond, uncond, imgcond = x.chunk(3)
            x = self.renorm(uncond.add(cond.sub(imgcond).mul_(self.guidance_scale)), cond)
            return x.add_(imgcond.sub_(uncond).mul_(self.image_guidance_scale))
        if self.spatiotemporal_guidance_scale:
            cond, uncond, perturb = x.chunk(3)
            x = self.renorm(uncond.add_(cond.sub(uncond).mul_(self.guidance_scale)), cond)
            # ``cond`` is consumed only after renorm has used it, so mutating it is safe.
            return x.add_(cond.sub_(perturb).mul_(self.spatiotemporal_guidance_scale))
        cond, uncond = x.chunk(2)
        return self.renorm(uncond.add_(cond.sub(uncond).mul_(self.guidance_scale)), cond)
|
|