from collections import defaultdict
from typing import Union, List, Tuple

import numpy as np
import torch
from torch import Tensor, nn
import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import scripts

from scripts.cutofflib.sdhook import SDHook
from scripts.cutofflib.embedding import CLIP, generate_prompts, token_to_block
from scripts.cutofflib.utils import log, set_debug
from scripts.cutofflib.xyz import init_xyz

NAME = 'Cutoff'
PAD = '_</w>'
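
# Note: '</w>' is the end-of-word marker in the CLIP BPE vocabulary, so the
# default padding token above denotes the literal word "_". Paddings typed
# into the UI are normalized the same way in `Script.process()` below.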


def check_neg(s: str, negative_prompt: str, all_negative_prompts: Union[List[str], None]) -> bool:
    """Return True if `s` is the negative prompt, or one of the batch's negative prompts."""
    if s == negative_prompt:
        return True

    if all_negative_prompts is not None:
        return s in all_negative_prompts

    return False


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherical linear interpolation between two vectors.

    Accepts numpy arrays or torch tensors; torch inputs are round-tripped
    through numpy and returned on their original device.
    """
    inputs_are_torch = False
    input_device = None
    if not isinstance(v0, np.ndarray):
        # Torch tensors: remember the device, then work in numpy.
        # (`.device` is read here, after the check, so that plain numpy
        # inputs do not raise AttributeError.)
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly colinear vectors: fall back to linear interpolation.
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
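
# Quick sanity check for `slerp` (illustrative only; nothing in the extension
# executes this):
#
#   >>> a, b = torch.randn(768), torch.randn(768)
#   >>> torch.allclose(slerp(0.0, a, b), a)
#   True
#   >>> mid = slerp(0.5, a, b)  # halfway along the arc from a to b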


class Hook(SDHook):

    def __init__(
        self,
        enabled: bool,
        targets: List[str],
        padding: Union[str, int],
        weight: float,
        disable_neg: bool,
        strong: bool,
        interpolate: str,
    ):
        super().__init__(enabled)
        self.targets = targets
        self.padding = padding
        self.weight = float(weight)
        self.disable_neg = disable_neg
        self.strong = strong
        self.intp = interpolate

    def interpolate(self, t1: Tensor, t2: Tensor, w):
        if self.intp == 'lerp':
            return torch.lerp(t1, t2, w)
        else:
            return slerp(w, t1, t2)
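
    # With w=0 both methods return `t1` (the original embedding) unchanged;
    # with w=1 they return `t2` (the embedding from the cut-off prompt).
    # `torch.lerp(t1, t2, w)` computes t1 + w*(t2 - t1) elementwise, while
    # `slerp` travels along the arc between the two vectors instead.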

    def hook_clip(self, p: StableDiffusionProcessing, clip: nn.Module):

        skip = False

        def hook(mod: nn.Module, inputs: Tuple[List[str]], output: Tensor):
            nonlocal skip

            if skip:
                # We are inside our own `mod(ks)` call below; do not recurse.
                return

            assert isinstance(mod, CLIP)

            prompts, *rest = inputs
            assert len(prompts) == output.shape[0]

            # Skip the negative prompt(s) if requested. The hook only sees
            # raw text, so negatives are detected by comparing against the
            # prompts stored on the processing object.
            if self.disable_neg:
                if all(check_neg(x, p.negative_prompt, p.all_negative_prompts) for x in prompts):
                    return

            output = output.clone()
            for pidx, prompt in enumerate(prompts):
                tt = token_to_block(mod, prompt)

                cutoff = generate_prompts(mod, prompt, self.targets, self.padding)
                switch_base = np.full_like(cutoff.sw, self.strong)
                switch = np.full_like(cutoff.sw, True)
                active = cutoff.active_blocks()

                # Map each cut-off prompt variant to the token positions
                # whose embeddings it should supply.
                prompt_to_tokens = defaultdict(list)
                for tidx, (token, block_index) in enumerate(tt):
                    if block_index in active:
                        sw = switch.copy()
                        sw[block_index] = False
                        text = cutoff.text(sw)
                    else:
                        text = cutoff.text(switch_base)
                    prompt_to_tokens[text].append((tidx, token))

                ks = list(prompt_to_tokens.keys())
                if len(ks) == 0:
                    # No prompt variants at all (e.g. an empty prompt).
                    ks.append('')

                try:
                    # Re-entrant call into CLIP; `skip` guards the hook above.
                    skip = True
                    vs = mod(ks)
                finally:
                    skip = False

                # Blend the embeddings of the cut-off variants into the
                # original embedding, token by token.
                tensor = output[pidx, :, :]  # e.g. (77, 768)
                for k, t in zip(ks, vs):
                    assert tensor.shape == t.shape
                    for tidx, token in prompt_to_tokens[k]:
                        log(f'{tidx:03} {token.token:<16} {k}')
                        tensor[tidx, :] = self.interpolate(tensor[tidx, :], t[tidx, :], self.weight)

            return output

        self.hook_layer(clip, hook)
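
# Note: the component list returned by `Script.ui()` must match the parameter
# order of `Script.process()` after `p`; the web UI forwards the UI values to
# `process()` positionally.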


class Script(scripts.Script):

    def __init__(self):
        super().__init__()
        self.last_hooker: Union[SDHook, None] = None

    def title(self):
        return NAME

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Accordion(NAME, open=False):
            enabled = gr.Checkbox(label='Enabled', value=False)
            targets = gr.Textbox(label='Target tokens (comma separated)', placeholder='red, blue')
            weight = gr.Slider(minimum=-1.0, maximum=2.0, step=0.01, value=0.5, label='Weight')
            with gr.Accordion('Details', open=False):
                disable_neg = gr.Checkbox(value=True, label='Disable for Negative prompt.')
                strong = gr.Checkbox(value=False, label='Cutoff strongly.')
                padding = gr.Textbox(label='Padding token (ID or single token)')
                lerp = gr.Radio(choices=['Lerp', 'SLerp'], value='Lerp', label='Interpolation method')
                debug = gr.Checkbox(value=False, label='Debug log')
                debug.change(fn=set_debug, inputs=[debug], outputs=[])

        return [
            enabled,
            targets,
            weight,
            disable_neg,
            strong,
            padding,
            lerp,
            debug,
        ]

    def process(
        self,
        p: StableDiffusionProcessing,
        enabled: bool,
        targets_: str,
        weight: Union[float, int],
        disable_neg: bool,
        strong: bool,
        padding: Union[str, int],
        intp: str,
        debug: bool,
    ):
        set_debug(debug)

        # Remove the hook installed for the previous generation, if any.
        if self.last_hooker is not None:
            self.last_hooker.__exit__(None, None, None)
            self.last_hooker = None

        if not enabled:
            return

        if targets_ is None or len(targets_) == 0:
            return

        targets = [x.strip() for x in targets_.split(',')]
        targets = [x for x in targets if len(x) != 0]

        if len(targets) == 0:
            return

        # Normalize the padding: empty -> default, numeric -> raw token ID,
        # anything else -> a word token (with the '</w>' marker appended).
        if padding is None:
            padding = PAD
        elif isinstance(padding, str):
            if len(padding) == 0:
                padding = PAD
            else:
                try:
                    padding = int(padding)
                except ValueError:
                    if not padding.endswith('</w>'):
                        padding += '</w>'

        weight = float(weight)
        intp = intp.lower()

        self.last_hooker = Hook(
            enabled=True,
            targets=targets,
            padding=padding,
            weight=weight,
            disable_neg=disable_neg,
            strong=strong,
            interpolate=intp,
        )

        self.last_hooker.setup(p)
        self.last_hooker.__enter__()

        p.extra_generation_params.update({
            f'{NAME} enabled': enabled,
            f'{NAME} targets': targets,
            f'{NAME} padding': padding,
            f'{NAME} weight': weight,
            f'{NAME} disable_for_neg': disable_neg,
            f'{NAME} strong': strong,
            f'{NAME} interpolation': intp,
        })
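
        # The entries above are serialized into the image's infotext by the
        # web UI, e.g. "Cutoff enabled: True, Cutoff targets: ['red'], ..."
        # (the example values and exact formatting are illustrative).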


init_xyz(Script, NAME)