| | import glob
|
| | from os import path
|
| | from paths import get_file_name, FastStableDiffusionPaths
|
| | from pathlib import Path
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class _lora_info:
    """Bookkeeping record for one LoRA weights file used with a pipeline.

    Attributes:
        path: filesystem path of the LoRA weights file.
        adapter_name: adapter name derived from ``path`` via ``get_file_name``.
        weight: adapter weight to apply for this LoRA.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    # NOTE: the former ``__del__`` only reset attributes to None, which has no
    # observable effect on an object being destroyed; it was removed.

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"adapter_name={self.adapter_name!r}, weight={self.weight!r})"
        )
|
| |
|
| |
|
# LoRA records (``_lora_info``) registered against ``_current_pipeline``,
# in the order they were loaded.
_loaded_loras = list()
# Pipeline object the registry above belongs to; None until the first load.
_current_pipeline = None
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA described by ``lcm_diffusion_setting.lora`` into ``pipeline``.

    The module keeps a registry (``_loaded_loras``) of LoRAs loaded into
    ``_current_pipeline``. When a different pipeline object is passed in, the
    registry is reset, because adapters loaded into the previous pipeline do
    not exist on the new one.

    Nothing is loaded or registered unless ``lcm_diffusion_setting.lora.enabled``
    is true. After loading, the adapter weights are applied via
    ``update_lora_weights`` and, if ``lora.fuse`` is set, ``pipeline.fuse_lora()``
    is called.

    :param pipeline: pipeline exposing ``load_lora_weights`` and ``fuse_lora``.
    :param lcm_diffusion_setting: settings whose ``lora`` attribute provides
        ``path``, ``weight``, ``enabled`` and ``fuse``.
    :raises ValueError: if the LoRA path is empty.
    :raises FileNotFoundError: if the LoRA path does not exist.
    """
    global _loaded_loras
    global _current_pipeline

    lora_settings = lcm_diffusion_setting.lora
    if not lora_settings.path:
        raise ValueError("Empty lora model path")
    if not path.exists(lora_settings.path):
        raise FileNotFoundError("Lora model path is invalid")

    # A new pipeline starts without any LoRA adapters: forget the records
    # that belonged to the previous pipeline.
    if pipeline is not _current_pipeline:
        _loaded_loras = []
        _current_pipeline = pipeline

    if not lora_settings.enabled:
        return

    current_lora = _lora_info(
        lora_settings.path,
        lora_settings.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")

    # Load from the directory that actually contains the validated file, so
    # LoRAs located in sub-folders of the models directory are found too.
    lora_file = Path(lora_settings.path)
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Register only after a successful load so set_adapters() never refers
    # to an adapter name the pipeline does not know about.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lora_settings.fuse:
        pipeline.fuse_lora()
|
| |
|
| |
|
def get_lora_models(root_dir: str):
    """Map LoRA model names to ``.safetensors`` file paths found under ``root_dir``.

    The search is recursive. Names come from ``get_file_name``; files for
    which it returns None are skipped. Paths are visited in sorted order, so
    when two files share a name, the one sorting last wins deterministically.

    :param root_dir: directory to search.
    :return: dict of ``{lora_name: file_path}``.
    """
    # Escape root_dir so characters like '[', '*' or '?' in the directory
    # name are matched literally instead of being treated as glob syntax.
    pattern = f"{glob.escape(root_dir)}/**/*.safetensors"
    lora_models_map = {}
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
|
| |
|
| |
|
| |
|
| |
|
def get_active_lora_weights():
    """Return ``(adapter_name, weight)`` tuples for every registered LoRA.

    The tuples follow the order in which the LoRAs were registered.
    """
    return [(record.adapter_name, record.weight) for record in _loaded_loras]
|
| |
|
| |
|
| |
|
| |
|
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Apply the registered LoRA adapter weights to ``pipeline``.

    :param pipeline: must be the pipeline the registered LoRAs belong to
        (``_current_pipeline``); otherwise a message is printed and nothing
        changes.
    :param lcm_diffusion_setting: settings object; when ``use_lcm_lora`` is
        true, the ``"lcm"`` adapter is activated with weight 1.0.
    :param lora_weights: optional sequence of ``(adapter_name, weight)`` pairs,
        positionally matching the registered LoRAs, used to update the stored
        weights before they are applied. Entries whose name does not match
        the registered LoRA at that position are skipped.
    """
    # Only reads the module-level state, so no ``global`` statement is needed.
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        requested = list(lora_weights)
        if len(requested) > len(_loaded_loras):
            # Previously this raised IndexError; extra entries are now ignored.
            print("More LoRA weights supplied than LoRAs loaded; extra entries ignored")
        for record, entry in zip(_loaded_loras, requested):
            if record.adapter_name != entry[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            record.weight = entry[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for record in _loaded_loras:
        adapter_names.append(record.adapter_name)
        adapter_weights.append(record.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")
|
| |
|