| import json |
| import math |
| import os |
| import sys |
| import warnings |
|
|
| import torch |
| import numpy as np |
| from PIL import Image, ImageFilter, ImageOps |
| import random |
| import cv2 |
| from skimage import exposure |
| from typing import Any, Dict, List, Optional |
|
|
| import modules.sd_hijack |
| from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste |
| from modules.sd_hijack import model_hijack |
| from modules.shared import opts, cmd_opts, state |
| import modules.shared as shared |
| import modules.face_restoration |
| import modules.images as images |
| import modules.styles |
| import logging |
|
|
|
|
| |
# opt_C: channel count of the noise tensors built by this module (first entry of
#        the shape passed to create_random_tensors).
# opt_f: factor dividing pixel width/height to get latent width/height.
opt_C, opt_f = 4, 8
|
|
|
|
def setup_color_correction(image):
    """Build the colour-correction reference for ``image``.

    The image is converted to a NumPy array and from RGB to LAB with OpenCV.
    The returned LAB array is the reference that ``apply_color_correction``
    matches histograms against.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
|
|
|
|
def apply_color_correction(correction, image):
    """Match the colour histograms of ``image`` to a stored reference.

    Args:
        correction: LAB array as produced by ``setup_color_correction``.
        image: PIL image; its pixels are converted assuming RGB channel order.

    Returns:
        A new PIL image. It is built by converting ``image`` to LAB, matching
        each channel's histogram (``channel_axis=2``) to ``correction``,
        converting back to RGB, and casting to ``uint8``.
    """
    logging.info("Applying color correction.")
    source_lab = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)
    return Image.fromarray(matched_rgb.astype("uint8"))
|
|
|
|
def apply_overlay(image, paste_loc, index, overlays):
    """Composite the overlay at ``overlays[index]`` on top of ``image``.

    If ``overlays`` is None or has no entry for ``index``, ``image`` is
    returned unchanged (same object).

    When ``paste_loc`` is an ``(x, y, w, h)`` tuple, ``image`` is first resized
    to ``w`` x ``h``. It is then pasted at ``(x, y)`` onto a transparent RGBA
    canvas the size of the overlay.

    The (possibly pasted) image is alpha-composited with the overlay and
    returned in RGB mode.
    """
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, region_w, region_h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, region_w, region_h), (left, top))
        image = canvas

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
|
|
|
|
class StableDiffusionProcessing:
    """
    Shared parameters and helpers for one txt2img/img2img generation job.

    The constructor parameters from ``sd_model`` through
    ``do_not_reload_embeddings`` are the minimum needed to create an instance.
    Subclasses provide ``init`` (a no-op here) and ``sample`` (which raises
    ``NotImplementedError`` here).
    """
    def __init__(
        self,
        sd_model=None,
        outpath_samples=None,
        outpath_grids=None,
        prompt: str = "",
        styles: List[str] = None,
        seed: int = -1,
        subseed: int = -1,
        subseed_strength: float = 0,
        seed_resize_from_h: int = -1,
        seed_resize_from_w: int = -1,
        seed_enable_extras: bool = True,
        sampler_name: str = None,
        batch_size: int = 1,
        n_iter: int = 1,
        steps: int = 50,
        cfg_scale: float = 7.0,
        width: int = 512,
        height: int = 512,
        restore_faces: bool = False,
        tiling: bool = False,
        do_not_save_samples: bool = False,
        do_not_save_grid: bool = False,
        extra_generation_params: Dict[Any, Any] = None,
        overlay_images: Any = None,
        negative_prompt: str = None,
        eta: float = None,
        do_not_reload_embeddings: bool = False,
        denoising_strength: float = 0,
        ddim_discretize: str = None,
        s_churn: float = 0.0,
        s_tmax: float = None,
        s_tmin: float = 0.0,
        s_noise: float = 1.0,
        override_settings: Dict[str, Any] = None,
        sampler_index: int = None,
    ):
        if sampler_index is not None:
            warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")

        # --- model and output destinations ---
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids

        # --- prompts ---
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = negative_prompt or ""
        self.styles: list = styles or []

        # --- seeds; variation and resize extras are reset when disabled ---
        self.seed: int = seed
        if seed_enable_extras:
            self.subseed: int = subseed
            self.subseed_strength: float = subseed_strength
            self.seed_resize_from_h: int = seed_resize_from_h
            self.seed_resize_from_w: int = seed_resize_from_w
        else:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # --- sampling ---
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.eta = eta
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None

        # --- post-processing and saving ---
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None

        # --- sampler tuning: falsy arguments fall back to the global options ---
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')
        self.s_noise = s_noise or opts.s_noise

        # Keys listed in shared.restricted_opts are dropped from per-job overrides.
        self.override_settings = {key: value for key, value in (override_settings or {}).items() if key not in shared.restricted_opts}
        self.is_using_inpainting_conditioning = False

        # Populated later (e.g. the all_* lists are filled by the main processing loop).
        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Return the image-conditioning tensor for a txt2img latent batch ``x``.

        For samplers whose ``conditioning_key`` is neither ``'hybrid'`` nor
        ``'concat'``, the result is a ``(batch, 5, 1, 1)`` zero placeholder.

        Otherwise, a black image of ``height`` x ``width`` pixels (defaulting to
        the job size) is encoded with the first-stage model. An all-ones mask
        channel is prepended, the result is cast to ``x.dtype``, and
        ``is_using_inpainting_conditioning`` is set.
        """
        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            self.is_using_inpainting_conditioning = True

            blank = torch.zeros(x.shape[0], 3, height or self.height, width or self.width, device=x.device)
            encoded = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(blank))

            # Pad one slot in front of dim 1 with 1.0, i.e. prepend a full mask channel.
            with_mask = torch.nn.functional.pad(encoded, (0, 0, 0, 0, 1, 0), value=1.0)
            return with_mask.to(x.dtype)

        return x.new_zeros(x.shape[0], 5, 1, 1)

    def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
        """Return the image-conditioning tensor for an img2img/inpainting batch.

        For samplers whose ``conditioning_key`` is neither ``'hybrid'`` nor
        ``'concat'``, the result is a ``(batch, 5, 1, 1)`` zero placeholder
        shaped after ``latent_image``.

        Otherwise the mask (a tensor, a PIL image, or all ones when None) is
        rounded to 0/1. ``source_image`` is blended towards its masked-out
        version by ``inpainting_mask_weight``, then encoded. The mask,
        resized to the latent size, is concatenated in front of the encoding.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

        self.is_using_inpainting_conditioning = True

        if image_mask is None:
            mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
        elif torch.is_tensor(image_mask):
            mask = torch.round(image_mask)
        else:
            pixels = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
            mask = torch.round(torch.from_numpy(pixels[None, None]))

        mask = mask.to(source_image.device).to(source_image.dtype)
        weight = getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        masked_source = torch.lerp(source_image, source_image * (1.0 - mask), weight)
        encoded = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(masked_source))

        mask = torch.nn.functional.interpolate(mask, size=latent_image.shape[-2:])
        mask = mask.expand(encoded.shape[0], -1, -1, -1)
        return torch.cat([mask, encoded], dim=1).to(shared.device).type(self.sd_model.dtype)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Per-job setup hook called once before sampling; no-op in the base class."""
        return None

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Sample one batch of latents; must be implemented by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Drop the references to the model and the sampler."""
        self.sd_model = self.sampler = None
|
|
|
|
class Processed:
    """Result of a processing job: output images plus the parameters that produced them."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        def first_if_list(value):
            # Exactly-list values collapse to their first element; others pass through.
            return value[0] if type(value) is list else value

        self.images = images_list
        self.info = info
        self.index_of_first_image = index_of_first_image

        # Scalar views of possibly per-image values; a missing seed/subseed becomes -1.
        self.prompt = first_if_list(p.prompt)
        self.negative_prompt = first_if_list(p.negative_prompt)
        self.seed = -1 if seed is None else int(first_if_list(seed))
        self.subseed = -1 if subseed is None else int(first_if_list(subseed))

        # Attributes copied verbatim from the processing object.
        copied_from_p = (
            "subseed_strength", "width", "height", "sampler_name", "cfg_scale", "steps",
            "batch_size", "restore_faces", "seed_resize_from_w", "seed_resize_from_h",
            "extra_generation_params", "styles", "eta", "ddim_discretize", "s_churn",
            "s_tmin", "s_tmax", "s_noise", "sampler_noise_scheduler_override",
            "is_using_inpainting_conditioning",
        )
        for name in copied_from_p:
            setattr(self, name, getattr(p, name))

        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        # Explicit arguments win, then the lists on p, then a single-item list.
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Serialize the result's metadata (not the images) to a JSON string."""
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
        }
        # Remaining keys are identical to attribute names; order is preserved in the output.
        obj.update((name, getattr(self, name)) for name in (
            "seed", "all_seeds", "subseed", "all_subseeds", "subseed_strength",
            "width", "height", "sampler_name", "cfg_scale", "steps", "batch_size",
            "restore_faces", "face_restoration_model", "sd_model_hash",
            "seed_resize_from_w", "seed_resize_from_h", "denoising_strength",
            "extra_generation_params", "index_of_first_image", "infotexts", "styles",
            "job_timestamp", "clip_skip", "is_using_inpainting_conditioning",
        ))
        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Rebuild the infotext for the image at flat position ``index``."""
        iteration, position_in_batch = divmod(index, self.batch_size)
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=position_in_batch, iteration=iteration)
|
|
|
|
| |
def slerp(val, low, high):
    """Spherically interpolate from ``low`` (``val == 0``) to ``high`` (``val == 1``).

    The angle between the inputs is measured on copies normalized along
    dim 1. The returned tensor interpolates the un-normalized inputs.

    When the mean cosine exceeds 0.9995, the inputs are nearly parallel and
    ``sin(omega)`` in the denominator approaches 0. Linear interpolation with
    the same orientation is used instead.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Weight (1 - val) on ``low`` and ``val`` on ``high``, matching the slerp
        # branch below (previously the weights were swapped). For identical
        # inputs, e.g. seed == subseed, both orientations give the same result.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
|
|
|
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Create one seeded noise tensor per entry of ``seeds`` and stack them.

    Args:
        shape: per-image noise shape, indexed below as ``shape[0..2]``
            (channels, height, width); callers pass e.g.
            ``[opt_C, height // opt_f, width // opt_f]``.
        seeds: seed for each image's main noise.
        subseeds: optional variation seeds, one per image; missing entries use 0.
        subseed_strength: ``val`` passed to ``slerp`` when blending the seed
            noise toward the subseed noise.
        seed_resize_from_h, seed_resize_from_w: when both are > 0, the seeded
            noise is drawn at ``(shape[0], h // 8, w // 8)``. Its centered
            overlap is then copied into a full-size tensor of ``shape``.
        p: optional processing object. When it has a sampler (and batch seeds
            or ENSD are active), extra per-image noises are also stored in
            ``p.sampler.sampler_noises``.

    Returns:
        The per-image noises stacked along a new leading dimension, moved to
        ``shared.device``.

    Note: the order of RNG calls here determines the generated images for a
    given seed; changing it changes results.
    """
    xs = []

    # Extra sampler noises are only generated when the job has a sampler and either
    # several seeds are processed with opts.enable_batch_seeds, or an eta noise
    # seed delta (ENSD) is configured.
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        # With seed resizing active, noise is drawn at the "resize from" size (// 8) instead of `shape`.
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # Main noise for this seed (drawn by devices.randn for the given seed).
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            # Variation seed: blend the seed noise toward the subseed noise.
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # Start from full-size noise for the same seed, then copy the centered
            # overlap of the resized noise into it. A negative dx/dy means the
            # source is larger than the target along that axis and gets cropped.
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if opts.eta_noise_seed_delta > 0:
                # Reseed torch's global RNG so the following draws depend on seed + ENSD.
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            # NOTE(review): these use noise_shape, not shape, so with seed resizing they
            # differ from the final noise shape — confirm the sampler handles the mismatch.
            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
|
|
|
|
def decode_first_stage(model, x):
    """Decode latents ``x`` with ``model.decode_first_stage``.

    Autocast is disabled when ``x`` already has the VAE dtype
    (``devices.dtype_vae``).
    """
    already_vae_dtype = x.dtype == devices.dtype_vae
    with devices.autocast(disable=already_vae_dtype):
        decoded = model.decode_first_stage(x)
    return decoded
|
|
|
|
def get_fixed_seed(seed):
    """Return ``seed`` unchanged unless it asks for a random seed.

    ``None``, ``''`` and ``-1`` mean "pick one". In that case a random integer
    in ``[0, 4294967294)`` is drawn from the ``random`` module.
    """
    wants_random = seed is None or seed == '' or seed == -1
    if not wants_random:
        return seed
    return int(random.randrange(4294967294))
|
|
|
|
def fix_seed(p):
    """Resolve random-seed placeholders on ``p.seed`` and ``p.subseed`` in place."""
    for attr in ("seed", "subseed"):
        setattr(p, attr, get_fixed_seed(getattr(p, attr)))
|
|
|
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Build the generation-parameters text for one image.

    Args:
        p: processing object supplying the job parameters.
        all_prompts, all_seeds, all_subseeds: flat per-image lists, indexed by
            ``position_in_batch + iteration * p.batch_size``.
        comments: accepted for interface compatibility; not read here.
        iteration: batch number of the image.
        position_in_batch: position of the image inside its batch.

    Returns:
        A string with three parts, stripped of surrounding whitespace:
        - the image's prompt;
        - a "Negative prompt:" line when that image has a negative prompt;
        - a line of comma-separated "key: value" pairs.
        Parameters whose value is None are omitted. A parameter whose key
        equals its value is written as just the key.
    """
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    # Script/subclass-provided parameters override the defaults above.
    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    # Use the negative prompt of *this* image, like the prompt and seed above
    # (previously index 0 was always used, which mislabels images when negative
    # prompts vary per image). Fall back to the first entry if the list is
    # shorter than ``index`` requires.
    negative_prompts = p.all_negative_prompts
    negative_prompt = negative_prompts[index] if index < len(negative_prompts) else negative_prompts[0]
    negative_prompt_text = "\nNegative prompt: " + negative_prompt if negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
|
|
|
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run ``p`` with its per-job option overrides temporarily applied.

    Each key of ``p.override_settings`` is set on ``opts`` before processing.
    The previous values are restored afterwards, even if processing raises.
    Whenever ``sd_hypernetwork`` is set (in either direction),
    ``shared.reload_hypernetworks()`` is called.
    """
    previous_values = {key: opts.data[key] for key in p.override_settings.keys()}

    def _apply(settings):
        # Push each setting onto opts, reloading hypernetworks when that option changes.
        for key, value in settings.items():
            setattr(opts, key, value)
            if key == 'sd_hypernetwork':
                shared.reload_hypernetworks()

    try:
        _apply(p.override_settings)
        return process_images_inner(p)
    finally:
        _apply(previous_values)
|
|
|
|
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Steps:
    1. Expand prompts, negative prompts, seeds and subseeds into per-image
       lists on ``p``.
    2. Write the first image's infotext to ``params.txt`` in
       ``shared.script_path``.
    3. Call ``p.init`` once.
    4. For each of ``p.n_iter`` batches:
       - compute the conditionings;
       - call ``p.sample`` and decode the latents;
       - post-process each image (NSFW filter, face restoration, colour
         correction, overlay) and save it.
    5. Optionally build and save a grid.
    6. Return a ``Processed`` describing everything.
    """

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Comments reported by model_hijack during conditioning; used as an
    # insertion-ordered set (only the keys are read when building the result).
    comments = {}

    # Apply styles and expand the prompt into one entry per image.
    if type(p.prompt) == list:
        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
    else:
        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

    if type(p.negative_prompt) == list:
        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
    else:
        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

    # Consecutive seeds per image, unless a variation strength is set. In that
    # case every image shares the base seed and only the subseeds increment.
    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        # Infotext for one image of this job, built from the per-image lists above.
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    # Record the parameters of the first image in params.txt.
    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            # A skip request only applies to the batch in progress; clear it and continue.
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            # Slice this batch's entries out of the per-image lists.
            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            # Decode latents and map the decoder output from [-1, 1] to [0, 1].
            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                # The local name `safety` is unused; the call resolves through the
                # `modules` package attribute that this import sets.
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                # CHW float in [0, 1] -> HWC uint8 for PIL.
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                # NOTE(review): `opts.save` below — confirm it is a setting value and not a
                # method on the options object (a bound method would always be truthy).
                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        # Optional grid of all images; when returned it is placed first in the outputs.
        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
|
|
|
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image job with an optional two-pass "highres fix".

    When ``enable_hr`` is set, ``sample`` first generates at
    ``firstphase_width`` x ``firstphase_height``. It then upscales the
    result to ``width`` x ``height``: in latent space, or via decoded images,
    depending on ``opts.use_scale_latent_for_hires_fix``. Finally it refines
    the upscaled result with ``self.sampler.sample_img2img``.
    """
    sampler = None

    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        # 0 for either first-pass dimension means "choose automatically" (see init()).
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        # First-pass latent crop amounts (latent units), computed in init().
        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare highres-fix bookkeeping before sampling starts.

        Only acts when ``enable_hr`` is set. It then:
        - doubles ``state.job_count``, since each batch advances the job
          counter once more inside ``sample``;
        - records the requested first-pass size in ``extra_generation_params``;
        - picks a default first-pass size when either dimension is 0;
        - computes ``truncate_x``/``truncate_y``.
        """
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            # Recorded before the defaulting below, so an automatic size is stored as "0x0".
            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                # Scale the target to about 512*512 pixels, keeping its aspect ratio,
                # and round each side up to a multiple of 64.
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                # Unrounded scaled size; the rounding excess is cropped away in sample().
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)

            else:

                # Explicit first-pass size: keep the largest region of it that has the
                # target's aspect ratio; the excess along the other axis is cropped later.
                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                if width_ratio > height_ratio:
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            # Excess size converted to latent units (// opt_f); sample() crops half from each side.
            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Sample one batch and return its latents.

        Without ``enable_hr``, this is a single pass at ``width`` x ``height``.

        With ``enable_hr``:
        1. Sample at the first-pass size.
        2. Center-crop the latents to the target aspect ratio.
        3. Optionally save the intermediate images.
        4. Upscale to the target size.
        5. Run an img2img pass with fresh noise from the same seeds.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        # First pass at the (smaller) first-phase resolution.
        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))

        # Center-crop the latents by the amounts computed in init().
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

        # The bare string on the next line is a no-op statement used as a comment (kept unchanged).
        """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
        def save_intermediate(image, index):
            # Save image `index` of the batch before upscaling; latent batches are converted first.
            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index)

            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")

        if opts.use_scale_latent_for_hires_fix:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            # Upscale directly in latent space.
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

            # Only build the (more expensive) decoded-image conditioning when the
            # inpainting mask weight is below 1.0.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            # Upscale in pixel space: decode, resize each image, re-encode.
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            # Back to a [-1, 1] tensor on the device for the first-stage encoder.
            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # Drop the reference to the first-pass noise and collect garbage before the second pass.
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples
|
|
|
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image / inpainting job.

    ``init`` does all the preparation:
    - processes the mask (invert, blur, optional full-resolution crop);
    - builds per-image overlays and colour corrections;
    - encodes ``init_images`` into ``init_latent``.

    ``sample`` runs ``sample_img2img`` starting from that latent. When a mask
    is set, it restores the original latent outside the masked region.
    """
    sampler = None

    def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        # Optional separate mask for the latent; falls back to the processed image mask in init().
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        # inpainting_fill values:
        #   any value other than 1 pre-fills masked pixels with masking.fill;
        #   2 additionally replaces the masked latent with noise;
        #   3 zeroes the masked latent.
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        # Latent-space masks set in init(): `mask` keeps the original, `nmask` marks the area to regenerate.
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare masks, overlays, colour corrections and the initial latent for sampling.

        Raises:
            RuntimeError: if more ``init_images`` than ``batch_size`` are given.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            image_mask = image_mask.convert('L')

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

            if self.mask_blur > 0:
                image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # Work only on the (padded) bounding region of the mask, scaled up to
                # the job size; paste_to records where results go back in the full image.
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                # The overlay mask is a brightened copy (x2, clipped to 255) of the resized mask.
                np_mask = np.array(image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        # Colour corrections are only collected if enabled and not already provided.
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                # Overlay keeps the original pixels where the overlay mask is dark
                # (pasted through the inverted mask); the masked area stays transparent.
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # HWC uint8 -> CHW float in [0, 1].
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # A single init image is repeated across the whole batch, with its overlay and colour correction.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            # Fewer images than the batch size shrink the batch to match.
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        # Scale to [-1, 1] and encode into the initial latent.
        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if image_mask is not None:
            # Build a binary mask at latent resolution:
            # resize to the latent size, take one channel, round to 0/1,
            # then repeat it across 4 latent channels.
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # NOTE(review): the fill noise uses all_seeds[0:batch] (the first batch's seeds)
            # here in init(), so later batches reuse the same fill noise.
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run img2img sampling from ``init_latent`` for one batch and return the latents.

        With a mask, positions where ``nmask`` is 0 are reset to ``init_latent``.
        """
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
|
|