Spaces:
Running on Zero
Running on Zero
Set up Style My Portrait branding and app
Browse files- .gitattributes +1 -0
- README.md +2 -2
- SpaceGrotesk-Bold.ttf +0 -0
- app.py +271 -0
- flux2_klein_kv.patch +1565 -0
- reference.jpg +3 -0
- requirements.txt +8 -0
- styles.json +289 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
reference.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
---
|
| 2 |
title: Style My Portrait
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.13.0
|
|
|
|
| 1 |
---
|
| 2 |
title: Style My Portrait
|
| 3 |
+
emoji: 🖼️
|
| 4 |
+
colorFrom: yellow
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.13.0
|
SpaceGrotesk-Bold.ttf
ADDED
|
Binary file (86.5 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import subprocess
|
| 5 |
+
import spaces
|
| 6 |
+
|
| 7 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 8 |
+
|
| 9 |
+
def apply_patch():
|
| 10 |
+
import diffusers
|
| 11 |
+
site_packages = os.path.dirname(diffusers.__file__)
|
| 12 |
+
patch_file = os.path.join(os.path.dirname(__file__), "flux2_klein_kv.patch")
|
| 13 |
+
if os.path.exists(patch_file):
|
| 14 |
+
result = subprocess.run(
|
| 15 |
+
["patch", "-p2", "--forward", "--batch"],
|
| 16 |
+
cwd=os.path.dirname(site_packages),
|
| 17 |
+
stdin=open(patch_file),
|
| 18 |
+
capture_output=True,
|
| 19 |
+
text=True,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
apply_patch()
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from PIL import Image
|
| 27 |
+
from diffusers.pipelines.flux2.pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
| 28 |
+
import gradio as gr
|
| 29 |
+
|
| 30 |
+
APP_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 31 |
+
STYLES_PATH = os.path.join(APP_DIR, "styles.json")
|
| 32 |
+
REFERENCE_PATH = os.path.join(APP_DIR, "reference.jpg")
|
| 33 |
+
MODEL_ID = "black-forest-labs/FLUX.2-klein-9b-kv"
|
| 34 |
+
MANDATORY_PROMPT = (
|
| 35 |
+
"You must preserve the exact identity, facial features, body proportions, and pose of the subject. "
|
| 36 |
+
"Keep the original camera framing, object positions, and full scene layout unchanged. "
|
| 37 |
+
"Apply only the requested visual style while maintaining total consistency and realism."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
dtype = torch.bfloat16
|
| 41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 43 |
+
|
| 44 |
+
with open(STYLES_PATH, "r", encoding="utf-8") as f:
|
| 45 |
+
STYLES = json.load(f)
|
| 46 |
+
|
| 47 |
+
pipe = Flux2KleinKVPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype, token=HF_TOKEN)
|
| 48 |
+
pipe.to("cuda")
|
| 49 |
+
|
| 50 |
+
PREVIEW_CACHE = {}
|
| 51 |
+
|
| 52 |
+
def prepare_image(image, max_size=1024):
|
| 53 |
+
iw, ih = image.size
|
| 54 |
+
ar = iw / ih
|
| 55 |
+
if ar >= 1:
|
| 56 |
+
w = max_size
|
| 57 |
+
h = round(max_size / ar / 16) * 16
|
| 58 |
+
else:
|
| 59 |
+
h = max_size
|
| 60 |
+
w = round(max_size * ar / 16) * 16
|
| 61 |
+
w, h = max(256, min(max_size, w)), max(256, min(max_size, h))
|
| 62 |
+
return image.resize((w, h), Image.LANCZOS), w, h
|
| 63 |
+
|
| 64 |
+
def build_prompt(style_prompt):
|
| 65 |
+
return f"{MANDATORY_PROMPT} {style_prompt}"
|
| 66 |
+
|
| 67 |
+
@spaces.GPU(duration=85)
|
| 68 |
+
def transform_image(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
|
| 69 |
+
if image is None:
|
| 70 |
+
raise gr.Error("Upload a photo first!")
|
| 71 |
+
style = next((s for s in STYLES if s["id"] == style_id), None)
|
| 72 |
+
if not style:
|
| 73 |
+
raise gr.Error("Select a valid style!")
|
| 74 |
+
if randomize_seed:
|
| 75 |
+
seed = random.randint(0, MAX_SEED)
|
| 76 |
+
original_resized, w, h = prepare_image(image)
|
| 77 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 78 |
+
progress(0.2, desc=f"Applying {style['name']} style...")
|
| 79 |
+
result = pipe(
|
| 80 |
+
prompt=build_prompt(style["prompt"]),
|
| 81 |
+
image=[original_resized],
|
| 82 |
+
height=h,
|
| 83 |
+
width=w,
|
| 84 |
+
num_inference_steps=num_steps,
|
| 85 |
+
generator=generator,
|
| 86 |
+
).images[0]
|
| 87 |
+
progress(0.9, desc="Preparing before/after slider...")
|
| 88 |
+
return (original_resized, result), seed, original_resized, result
|
| 89 |
+
|
| 90 |
+
@spaces.GPU(duration=120)
|
| 91 |
+
def generate_previews(progress=gr.Progress(track_tqdm=True)):
|
| 92 |
+
global PREVIEW_CACHE
|
| 93 |
+
if PREVIEW_CACHE:
|
| 94 |
+
return PREVIEW_CACHE
|
| 95 |
+
if not os.path.exists(REFERENCE_PATH):
|
| 96 |
+
PREVIEW_CACHE = {s["id"]: None for s in STYLES}
|
| 97 |
+
return PREVIEW_CACHE
|
| 98 |
+
ref_img = Image.open(REFERENCE_PATH).convert("RGB")
|
| 99 |
+
_, w, h = prepare_image(ref_img)
|
| 100 |
+
for i, style in enumerate(STYLES):
|
| 101 |
+
progress(i / len(STYLES), desc=f"Generating preview: {style['name']}")
|
| 102 |
+
generator = torch.Generator(device=device).manual_seed(style.get("preview_seed", 42))
|
| 103 |
+
try:
|
| 104 |
+
res = pipe(
|
| 105 |
+
prompt=build_prompt(style["prompt"]),
|
| 106 |
+
image=[ref_img],
|
| 107 |
+
height=h,
|
| 108 |
+
width=w,
|
| 109 |
+
num_inference_steps=4,
|
| 110 |
+
generator=generator,
|
| 111 |
+
).images[0]
|
| 112 |
+
PREVIEW_CACHE[style["id"]] = res
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"Error generating preview for {style['name']}: {e}")
|
| 115 |
+
PREVIEW_CACHE[style["id"]] = None
|
| 116 |
+
return PREVIEW_CACHE
|
| 117 |
+
|
| 118 |
+
def load_previews_to_gallery():
|
| 119 |
+
previews = generate_previews()
|
| 120 |
+
gallery_items = []
|
| 121 |
+
for style in STYLES:
|
| 122 |
+
img = previews.get(style["id"])
|
| 123 |
+
if img:
|
| 124 |
+
gallery_items.append((img, f"{style['emoji']} {style['name']}"))
|
| 125 |
+
else:
|
| 126 |
+
placeholder = Image.new("RGB", (512, 512), (20, 18, 30))
|
| 127 |
+
gallery_items.append((placeholder, f"{style['emoji']} {style['name']}"))
|
| 128 |
+
return gallery_items
|
| 129 |
+
|
| 130 |
+
def select_style_from_gallery(evt: gr.SelectData):
|
| 131 |
+
style = STYLES[evt.index]
|
| 132 |
+
return style["id"]
|
| 133 |
+
|
| 134 |
+
css = """
|
| 135 |
+
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
|
| 136 |
+
.gradio-container { background: radial-gradient(circle at top left, #fff4e8 0%, #f7ddc0 32%, #d79a72 68%, #8f4f3d 100%) !important; max-width: 1120px !important; margin: 0 auto !important; font-family: 'Space Grotesk', sans-serif !important; }
|
| 137 |
+
.main-title h1 { text-align: center; color: #5a2416 !important; font-size: 2.7em !important; font-weight: 700 !important; letter-spacing: -1px; margin-bottom: 0 !important; }
|
| 138 |
+
.subtitle p { text-align: center; color: #6f3a2a !important; font-size: 1.05em !important; margin-top: 0 !important; }
|
| 139 |
+
.step-guide p { text-align: center; color: #6b3527 !important; font-size: 1.03em !important; margin: 6px 0 10px !important; }
|
| 140 |
+
.credit-line p { text-align: center; color: #7b3f2d !important; font-size: 0.98em !important; margin: 4px 0 20px !important; }
|
| 141 |
+
.credit-line a { color: #8f2d1f !important; text-decoration: none !important; font-weight: 700 !important; }
|
| 142 |
+
.credit-line a:hover { text-decoration: underline !important; }
|
| 143 |
+
#style-gallery { border: 1px solid #b96b4c !important; border-radius: 12px !important; background: #fff8f0 !important; }
|
| 144 |
+
#style-gallery .grid-wrap { gap: 8px !important; }
|
| 145 |
+
#style-gallery .thumbnail-item { border: 2px solid transparent !important; border-radius: 10px !important; transition: all 0.2s ease !important; }
|
| 146 |
+
#style-gallery .thumbnail-item:hover { border-color: #cf6e43aa !important; transform: translateY(-2px); }
|
| 147 |
+
#style-gallery .thumbnail-item.selected,
|
| 148 |
+
#style-gallery .thumbnail-item[aria-selected="true"] {
|
| 149 |
+
border-color: #bf5f34 !important;
|
| 150 |
+
box-shadow: 0 0 0 2px #cf6e4380, 0 10px 26px #8f4f3d40 !important;
|
| 151 |
+
transform: translateY(-2px) scale(1.02) !important;
|
| 152 |
+
}
|
| 153 |
+
#input-img { border: 2px dashed #c57d57 !important; border-radius: 12px !important; background: #fff8f0 !important; min-height: 320px; }
|
| 154 |
+
#output-slider { border: 1px solid #c57d57 !important; border-radius: 12px !important; background: #fff8f0 !important; overflow: hidden !important; }
|
| 155 |
+
#go-btn { background: linear-gradient(135deg, #8f2d1f, #d76a3f) !important; color: #fff8f0 !important; font-weight: 700 !important; font-size: 1.15em !important; min-height: 52px !important; border: none !important; border-radius: 12px !important; box-shadow: 0 4px 20px #8f2d1f33; transition: all 0.2s ease !important; }
|
| 156 |
+
#go-btn:hover { box-shadow: 0 6px 26px #8f2d1f4d; transform: translateY(-2px); }
|
| 157 |
+
#dl-btn { background: #fff3e6 !important; color: #8f2d1f !important; border: 1px solid #c57d57 !important; border-radius: 10px !important; }
|
| 158 |
+
.progress-bar { background-color: #d76a3f !important; }
|
| 159 |
+
.progress-bar-wrap { background-color: #f3d8bd !important; }
|
| 160 |
+
* { --neutral-50: #fff8f0 !important; --neutral-100: #fff3e6 !important; --neutral-200: #e7bb96 !important; }
|
| 161 |
+
.dark { --body-background-fill: #f7ddc0; }
|
| 162 |
+
footer { display: none !important; }
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
with gr.Blocks(title="Style My Portrait", css=css, theme=gr.themes.Base(
|
| 166 |
+
primary_hue=gr.themes.colors.orange,
|
| 167 |
+
secondary_hue=gr.themes.colors.red,
|
| 168 |
+
neutral_hue=gr.themes.colors.stone,
|
| 169 |
+
font=gr.themes.GoogleFont("Space Grotesk"),
|
| 170 |
+
).set(
|
| 171 |
+
body_background_fill="#f7ddc0",
|
| 172 |
+
body_background_fill_dark="#f7ddc0",
|
| 173 |
+
block_background_fill="#fff8f0",
|
| 174 |
+
block_background_fill_dark="#fff8f0",
|
| 175 |
+
block_border_color="#c57d57",
|
| 176 |
+
block_border_color_dark="#c57d57",
|
| 177 |
+
block_label_text_color="#8f2d1f",
|
| 178 |
+
block_label_text_color_dark="#8f2d1f",
|
| 179 |
+
block_title_text_color="#8f2d1f",
|
| 180 |
+
block_title_text_color_dark="#8f2d1f",
|
| 181 |
+
body_text_color="#5a2416",
|
| 182 |
+
body_text_color_dark="#5a2416",
|
| 183 |
+
button_primary_background_fill="#d76a3f",
|
| 184 |
+
button_primary_background_fill_dark="#d76a3f",
|
| 185 |
+
button_primary_text_color="#fff8f0",
|
| 186 |
+
button_primary_text_color_dark="#fff8f0",
|
| 187 |
+
input_background_fill="#fff3e6",
|
| 188 |
+
input_background_fill_dark="#fff3e6",
|
| 189 |
+
input_border_color="#c57d57",
|
| 190 |
+
input_border_color_dark="#c57d57",
|
| 191 |
+
border_color_accent="#d76a3f",
|
| 192 |
+
border_color_accent_dark="#d76a3f",
|
| 193 |
+
border_color_primary="#c57d57",
|
| 194 |
+
border_color_primary_dark="#c57d57",
|
| 195 |
+
background_fill_secondary="#fff8f0",
|
| 196 |
+
background_fill_secondary_dark="#fff8f0",
|
| 197 |
+
background_fill_primary="#f7ddc0",
|
| 198 |
+
background_fill_primary_dark="#f7ddc0",
|
| 199 |
+
shadow_drop="none",
|
| 200 |
+
shadow_drop_lg="none",
|
| 201 |
+
slider_color="#d76a3f",
|
| 202 |
+
slider_color_dark="#d76a3f",
|
| 203 |
+
checkbox_background_color="#fff3e6",
|
| 204 |
+
checkbox_background_color_dark="#fff3e6",
|
| 205 |
+
checkbox_background_color_selected="#d76a3f",
|
| 206 |
+
checkbox_background_color_selected_dark="#d76a3f",
|
| 207 |
+
)) as demo:
|
| 208 |
+
|
| 209 |
+
gr.Markdown("# Style My Portrait", elem_classes="main-title")
|
| 210 |
+
gr.Markdown("Turn your portrait into polished visual styles with FLUX.2 Klein", elem_classes="subtitle")
|
| 211 |
+
gr.Markdown("**1)** Upload your photo • **2)** Pick a style from the gallery • **3)** Click transform", elem_classes="step-guide")
|
| 212 |
+
gr.Markdown(
|
| 213 |
+
"Built by [@artificialguybr](https://twitter.com/artificialguybr) • Explore more image editing and image generation prompts at [artificialguy.com](https://artificialguy.com) and [findgoodprompt.com](https://findgoodprompt.com)",
|
| 214 |
+
elem_classes="credit-line",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
selected_style_id = gr.State(STYLES[0]["id"])
|
| 218 |
+
original_state = gr.State(None)
|
| 219 |
+
enhanced_state = gr.State(None)
|
| 220 |
+
|
| 221 |
+
with gr.Row(equal_height=True):
|
| 222 |
+
with gr.Column(scale=1):
|
| 223 |
+
input_image = gr.Image(label="📸 Upload & Preview", type="pil", elem_id="input-img")
|
| 224 |
+
with gr.Accordion("⚙️ Settings", open=False):
|
| 225 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
| 226 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 227 |
+
num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=20, step=1, value=4)
|
| 228 |
+
go_btn = gr.Button("🚀 Transform", elem_id="go-btn", variant="primary")
|
| 229 |
+
|
| 230 |
+
with gr.Column(scale=1):
|
| 231 |
+
gr.Markdown("### 🎨 Gallery", elem_classes="subtitle")
|
| 232 |
+
style_gallery = gr.Gallery(
|
| 233 |
+
label="Style Gallery",
|
| 234 |
+
show_label=False,
|
| 235 |
+
elem_id="style-gallery",
|
| 236 |
+
columns=[4], rows=[2],
|
| 237 |
+
height=360,
|
| 238 |
+
object_fit="cover",
|
| 239 |
+
allow_preview=False,
|
| 240 |
+
value=load_previews_to_gallery()
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
style_gallery.select(
|
| 244 |
+
fn=select_style_from_gallery,
|
| 245 |
+
inputs=[],
|
| 246 |
+
outputs=[selected_style_id]
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
gr.Markdown("### 🖼️ Result", elem_classes="subtitle")
|
| 250 |
+
with gr.Row():
|
| 251 |
+
output_image = gr.ImageSlider(label="Before / After", type="pil", elem_id="output-slider", slider_position=50)
|
| 252 |
+
with gr.Row():
|
| 253 |
+
dl_btn = gr.DownloadButton("📥 Download Image", elem_id="dl-btn", visible=False)
|
| 254 |
+
|
| 255 |
+
def on_generate(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
|
| 256 |
+
comparison, seed, orig, enh = transform_image(image, style_id, seed, randomize_seed, num_steps, progress)
|
| 257 |
+
return comparison, seed, orig, enh, gr.update(visible=True, value=enh)
|
| 258 |
+
|
| 259 |
+
go_btn.click(
|
| 260 |
+
fn=on_generate,
|
| 261 |
+
inputs=[input_image, selected_style_id, seed, randomize_seed, num_inference_steps],
|
| 262 |
+
outputs=[output_image, seed, original_state, enhanced_state, dl_btn],
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
input_image.change(
|
| 266 |
+
fn=lambda: (None, gr.update(visible=False), None, None),
|
| 267 |
+
inputs=[],
|
| 268 |
+
outputs=[output_image, dl_btn, original_state, enhanced_state],
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
demo.launch(ssr_mode=False)
|
flux2_klein_kv.patch
ADDED
|
@@ -0,0 +1,1565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
|
| 2 |
+
index 546fbe57b..0be7b8166 100644
|
| 3 |
+
--- a/src/diffusers/__init__.py
|
| 4 |
+
+++ b/src/diffusers/__init__.py
|
| 5 |
+
@@ -510,6 +510,7 @@ else:
|
| 6 |
+
"EasyAnimateControlPipeline",
|
| 7 |
+
"EasyAnimateInpaintPipeline",
|
| 8 |
+
"EasyAnimatePipeline",
|
| 9 |
+
+ "Flux2KleinKVPipeline",
|
| 10 |
+
"Flux2KleinPipeline",
|
| 11 |
+
"Flux2Pipeline",
|
| 12 |
+
"FluxControlImg2ImgPipeline",
|
| 13 |
+
@@ -1266,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 14 |
+
EasyAnimateControlPipeline,
|
| 15 |
+
EasyAnimateInpaintPipeline,
|
| 16 |
+
EasyAnimatePipeline,
|
| 17 |
+
+ Flux2KleinKVPipeline,
|
| 18 |
+
Flux2KleinPipeline,
|
| 19 |
+
Flux2Pipeline,
|
| 20 |
+
FluxControlImg2ImgPipeline,
|
| 21 |
+
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
|
| 22 |
+
index f77498c74..d2ecba583 100644
|
| 23 |
+
--- a/src/diffusers/models/transformers/transformer_flux2.py
|
| 24 |
+
+++ b/src/diffusers/models/transformers/transformer_flux2.py
|
| 25 |
+
@@ -40,6 +40,193 @@ from ..normalization import AdaLayerNormContinuous
|
| 26 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
+class Flux2KVLayerCache:
|
| 30 |
+
+ """Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
|
| 31 |
+
+
|
| 32 |
+
+ Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step.
|
| 33 |
+
+ Tensor format: (batch_size, num_ref_tokens, num_heads, head_dim).
|
| 34 |
+
+ """
|
| 35 |
+
+
|
| 36 |
+
+ def __init__(self):
|
| 37 |
+
+ self.k_ref: torch.Tensor | None = None
|
| 38 |
+
+ self.v_ref: torch.Tensor | None = None
|
| 39 |
+
+
|
| 40 |
+
+ def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
|
| 41 |
+
+ """Store reference token K/V."""
|
| 42 |
+
+ self.k_ref = k_ref
|
| 43 |
+
+ self.v_ref = v_ref
|
| 44 |
+
+
|
| 45 |
+
+ def get(self) -> tuple[torch.Tensor, torch.Tensor]:
|
| 46 |
+
+ """Retrieve cached reference token K/V."""
|
| 47 |
+
+ if self.k_ref is None:
|
| 48 |
+
+ raise RuntimeError("KV cache has not been populated yet.")
|
| 49 |
+
+ return self.k_ref, self.v_ref
|
| 50 |
+
+
|
| 51 |
+
+ def clear(self):
|
| 52 |
+
+ self.k_ref = None
|
| 53 |
+
+ self.v_ref = None
|
| 54 |
+
+
|
| 55 |
+
+
|
| 56 |
+
+class Flux2KVCache:
|
| 57 |
+
+ """Container for all layers' reference-token KV caches.
|
| 58 |
+
+
|
| 59 |
+
+ Holds separate cache lists for double-stream and single-stream transformer blocks.
|
| 60 |
+
+ """
|
| 61 |
+
+
|
| 62 |
+
+ def __init__(self, num_double_layers: int, num_single_layers: int):
|
| 63 |
+
+ self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
|
| 64 |
+
+ self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
|
| 65 |
+
+ self.num_ref_tokens: int = 0
|
| 66 |
+
+
|
| 67 |
+
+ def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
|
| 68 |
+
+ return self.double_block_caches[layer_idx]
|
| 69 |
+
+
|
| 70 |
+
+ def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
|
| 71 |
+
+ return self.single_block_caches[layer_idx]
|
| 72 |
+
+
|
| 73 |
+
+ def clear(self):
|
| 74 |
+
+ for cache in self.double_block_caches:
|
| 75 |
+
+ cache.clear()
|
| 76 |
+
+ for cache in self.single_block_caches:
|
| 77 |
+
+ cache.clear()
|
| 78 |
+
+ self.num_ref_tokens = 0
|
| 79 |
+
+
|
| 80 |
+
+
|
| 81 |
+
+def _flux2_kv_causal_attention(
|
| 82 |
+
+ query: torch.Tensor,
|
| 83 |
+
+ key: torch.Tensor,
|
| 84 |
+
+ value: torch.Tensor,
|
| 85 |
+
+ num_txt_tokens: int,
|
| 86 |
+
+ num_ref_tokens: int,
|
| 87 |
+
+ kv_cache: Flux2KVLayerCache | None = None,
|
| 88 |
+
+ backend=None,
|
| 89 |
+
+) -> torch.Tensor:
|
| 90 |
+
+ """Causal attention for KV caching where reference tokens only self-attend.
|
| 91 |
+
+
|
| 92 |
+
+ All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
|
| 93 |
+
+
|
| 94 |
+
+ Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens,
|
| 95 |
+
+ ref tokens only attend to themselves.
|
| 96 |
+
+ With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected between txt and img.
|
| 97 |
+
+ """
|
| 98 |
+
+ # No ref tokens and no cache — standard full attention
|
| 99 |
+
+ if num_ref_tokens == 0 and kv_cache is None:
|
| 100 |
+
+ return dispatch_attention_fn(query, key, value, backend=backend)
|
| 101 |
+
+
|
| 102 |
+
+ if kv_cache is not None:
|
| 103 |
+
+ # Cached mode: inject ref K/V between txt and img
|
| 104 |
+
+ k_ref, v_ref = kv_cache.get()
|
| 105 |
+
+
|
| 106 |
+
+ k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
|
| 107 |
+
+ v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
|
| 108 |
+
+
|
| 109 |
+
+ return dispatch_attention_fn(query, k_all, v_all, backend=backend)
|
| 110 |
+
+
|
| 111 |
+
+ # Extract mode: ref tokens self-attend, txt+img attend to all
|
| 112 |
+
+ ref_start = num_txt_tokens
|
| 113 |
+
+ ref_end = num_txt_tokens + num_ref_tokens
|
| 114 |
+
+
|
| 115 |
+
+ q_txt = query[:, :ref_start]
|
| 116 |
+
+ q_ref = query[:, ref_start:ref_end]
|
| 117 |
+
+ q_img = query[:, ref_end:]
|
| 118 |
+
+
|
| 119 |
+
+ k_txt = key[:, :ref_start]
|
| 120 |
+
+ k_ref = key[:, ref_start:ref_end]
|
| 121 |
+
+ k_img = key[:, ref_end:]
|
| 122 |
+
+
|
| 123 |
+
+ v_txt = value[:, :ref_start]
|
| 124 |
+
+ v_ref = value[:, ref_start:ref_end]
|
| 125 |
+
+ v_img = value[:, ref_end:]
|
| 126 |
+
+
|
| 127 |
+
+ # txt+img attend to all tokens
|
| 128 |
+
+ q_txt_img = torch.cat([q_txt, q_img], dim=1)
|
| 129 |
+
+ k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
|
| 130 |
+
+ v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
|
| 131 |
+
+ attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
|
| 132 |
+
+ attn_txt = attn_txt_img[:, :ref_start]
|
| 133 |
+
+ attn_img = attn_txt_img[:, ref_start:]
|
| 134 |
+
+
|
| 135 |
+
+ # ref tokens self-attend only
|
| 136 |
+
+ attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
|
| 137 |
+
+
|
| 138 |
+
+ return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
|
| 139 |
+
+
|
| 140 |
+
+
|
| 141 |
+
+def _blend_mod_params(
|
| 142 |
+
+ img_params: tuple[torch.Tensor, ...],
|
| 143 |
+
+ ref_params: tuple[torch.Tensor, ...],
|
| 144 |
+
+ num_ref: int,
|
| 145 |
+
+ seq_len: int,
|
| 146 |
+
+) -> tuple[torch.Tensor, ...]:
|
| 147 |
+
+ """Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
|
| 148 |
+
+ blended = []
|
| 149 |
+
+ for im, rm in zip(img_params, ref_params):
|
| 150 |
+
+ if im.ndim == 2:
|
| 151 |
+
+ im = im.unsqueeze(1)
|
| 152 |
+
+ rm = rm.unsqueeze(1)
|
| 153 |
+
+ B = im.shape[0]
|
| 154 |
+
+ blended.append(
|
| 155 |
+
+ torch.cat(
|
| 156 |
+
+ [rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
|
| 157 |
+
+ dim=1,
|
| 158 |
+
+ )
|
| 159 |
+
+ )
|
| 160 |
+
+ return tuple(blended)
|
| 161 |
+
+
|
| 162 |
+
+
|
| 163 |
+
+def _blend_double_block_mods(
|
| 164 |
+
+ img_mod: torch.Tensor,
|
| 165 |
+
+ ref_mod: torch.Tensor,
|
| 166 |
+
+ num_ref: int,
|
| 167 |
+
+ seq_len: int,
|
| 168 |
+
+) -> torch.Tensor:
|
| 169 |
+
+ """Blend double-block image-stream modulations for a [ref, img] sequence layout.
|
| 170 |
+
+
|
| 171 |
+
+ Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is
|
| 172 |
+
+ compatible with `Flux2Modulation.split(mod, 2)`.
|
| 173 |
+
+ """
|
| 174 |
+
+ img_mods = Flux2Modulation.split(img_mod, 2)
|
| 175 |
+
+ ref_mods = Flux2Modulation.split(ref_mod, 2)
|
| 176 |
+
+
|
| 177 |
+
+ all_params = []
|
| 178 |
+
+ for img_set, ref_set in zip(img_mods, ref_mods):
|
| 179 |
+
+ blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
|
| 180 |
+
+ all_params.extend(blended)
|
| 181 |
+
+ return torch.cat(all_params, dim=-1)
|
| 182 |
+
+
|
| 183 |
+
+
|
| 184 |
+
+def _blend_single_block_mods(
|
| 185 |
+
+ single_mod: torch.Tensor,
|
| 186 |
+
+ ref_mod: torch.Tensor,
|
| 187 |
+
+ num_txt: int,
|
| 188 |
+
+ num_ref: int,
|
| 189 |
+
+ seq_len: int,
|
| 190 |
+
+) -> torch.Tensor:
|
| 191 |
+
+ """Blend single-block modulations for a [txt, ref, img] sequence layout.
|
| 192 |
+
+
|
| 193 |
+
+ Takes raw modulation tensors and returns a blended raw tensor compatible with
|
| 194 |
+
+ `Flux2Modulation.split(mod, 1)`.
|
| 195 |
+
+ """
|
| 196 |
+
+ img_params = Flux2Modulation.split(single_mod, 1)[0]
|
| 197 |
+
+ ref_params = Flux2Modulation.split(ref_mod, 1)[0]
|
| 198 |
+
+
|
| 199 |
+
+ blended = []
|
| 200 |
+
+ for im, rm in zip(img_params, ref_params):
|
| 201 |
+
+ if im.ndim == 2:
|
| 202 |
+
+ im = im.unsqueeze(1)
|
| 203 |
+
+ rm = rm.unsqueeze(1)
|
| 204 |
+
+ B = im.shape[0]
|
| 205 |
+
+ im_expanded = im.expand(B, seq_len, -1)
|
| 206 |
+
+ rm_expanded = rm.expand(B, num_ref, -1)
|
| 207 |
+
+ blended.append(
|
| 208 |
+
+ torch.cat(
|
| 209 |
+
+ [im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
|
| 210 |
+
+ dim=1,
|
| 211 |
+
+ )
|
| 212 |
+
+ )
|
| 213 |
+
+ return torch.cat(blended, dim=-1)
|
| 214 |
+
+
|
| 215 |
+
+
|
| 216 |
+
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 217 |
+
query = attn.to_q(hidden_states)
|
| 218 |
+
key = attn.to_k(hidden_states)
|
| 219 |
+
@@ -181,9 +368,105 @@ class Flux2AttnProcessor:
|
| 220 |
+
return hidden_states
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
+class Flux2KVAttnProcessor:
|
| 224 |
+
+ """
|
| 225 |
+
+ Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
|
| 226 |
+
+
|
| 227 |
+
+ When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal
|
| 228 |
+
+ attention is used (ref tokens self-attend only, txt+img attend to all).
|
| 229 |
+
+ When `kv_cache_mode` is "cached", cached ref K/V are injected during attention.
|
| 230 |
+
+ When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
|
| 231 |
+
+ """
|
| 232 |
+
+
|
| 233 |
+
+ _attention_backend = None
|
| 234 |
+
+ _parallel_config = None
|
| 235 |
+
+
|
| 236 |
+
+ def __init__(self):
|
| 237 |
+
+ if not hasattr(F, "scaled_dot_product_attention"):
|
| 238 |
+
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
| 239 |
+
+
|
| 240 |
+
+ def __call__(
|
| 241 |
+
+ self,
|
| 242 |
+
+ attn: "Flux2Attention",
|
| 243 |
+
+ hidden_states: torch.Tensor,
|
| 244 |
+
+ encoder_hidden_states: torch.Tensor = None,
|
| 245 |
+
+ attention_mask: torch.Tensor | None = None,
|
| 246 |
+
+ image_rotary_emb: torch.Tensor | None = None,
|
| 247 |
+
+ kv_cache: Flux2KVLayerCache | None = None,
|
| 248 |
+
+ kv_cache_mode: str | None = None,
|
| 249 |
+
+ num_ref_tokens: int = 0,
|
| 250 |
+
+ ) -> torch.Tensor:
|
| 251 |
+
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 252 |
+
+ attn, hidden_states, encoder_hidden_states
|
| 253 |
+
+ )
|
| 254 |
+
+
|
| 255 |
+
+ query = query.unflatten(-1, (attn.heads, -1))
|
| 256 |
+
+ key = key.unflatten(-1, (attn.heads, -1))
|
| 257 |
+
+ value = value.unflatten(-1, (attn.heads, -1))
|
| 258 |
+
+
|
| 259 |
+
+ query = attn.norm_q(query)
|
| 260 |
+
+ key = attn.norm_k(key)
|
| 261 |
+
+
|
| 262 |
+
+ if attn.added_kv_proj_dim is not None:
|
| 263 |
+
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 264 |
+
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 265 |
+
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 266 |
+
+
|
| 267 |
+
+ encoder_query = attn.norm_added_q(encoder_query)
|
| 268 |
+
+ encoder_key = attn.norm_added_k(encoder_key)
|
| 269 |
+
+
|
| 270 |
+
+ query = torch.cat([encoder_query, query], dim=1)
|
| 271 |
+
+ key = torch.cat([encoder_key, key], dim=1)
|
| 272 |
+
+ value = torch.cat([encoder_value, value], dim=1)
|
| 273 |
+
+
|
| 274 |
+
+ if image_rotary_emb is not None:
|
| 275 |
+
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 276 |
+
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 277 |
+
+
|
| 278 |
+
+ num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
|
| 279 |
+
+
|
| 280 |
+
+ # Extract ref K/V from the combined sequence
|
| 281 |
+
+ if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
| 282 |
+
+ ref_start = num_txt_tokens
|
| 283 |
+
+ ref_end = num_txt_tokens + num_ref_tokens
|
| 284 |
+
+ kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
| 285 |
+
+
|
| 286 |
+
+ # Dispatch attention
|
| 287 |
+
+ if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
| 288 |
+
+ hidden_states = _flux2_kv_causal_attention(
|
| 289 |
+
+ query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
| 290 |
+
+ )
|
| 291 |
+
+ elif kv_cache_mode == "cached" and kv_cache is not None:
|
| 292 |
+
+ hidden_states = _flux2_kv_causal_attention(
|
| 293 |
+
+ query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
| 294 |
+
+ )
|
| 295 |
+
+ else:
|
| 296 |
+
+ hidden_states = dispatch_attention_fn(
|
| 297 |
+
+ query, key, value, attn_mask=attention_mask,
|
| 298 |
+
+ backend=self._attention_backend, parallel_config=self._parallel_config,
|
| 299 |
+
+ )
|
| 300 |
+
+
|
| 301 |
+
+ hidden_states = hidden_states.flatten(2, 3)
|
| 302 |
+
+ hidden_states = hidden_states.to(query.dtype)
|
| 303 |
+
+
|
| 304 |
+
+ if encoder_hidden_states is not None:
|
| 305 |
+
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 306 |
+
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 307 |
+
+ )
|
| 308 |
+
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 309 |
+
+
|
| 310 |
+
+ hidden_states = attn.to_out[0](hidden_states)
|
| 311 |
+
+ hidden_states = attn.to_out[1](hidden_states)
|
| 312 |
+
+
|
| 313 |
+
+ if encoder_hidden_states is not None:
|
| 314 |
+
+ return hidden_states, encoder_hidden_states
|
| 315 |
+
+ else:
|
| 316 |
+
+ return hidden_states
|
| 317 |
+
+
|
| 318 |
+
+
|
| 319 |
+
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
| 320 |
+
_default_processor_cls = Flux2AttnProcessor
|
| 321 |
+
- _available_processors = [Flux2AttnProcessor]
|
| 322 |
+
+ _available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
|
| 323 |
+
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
@@ -312,6 +595,86 @@ class Flux2ParallelSelfAttnProcessor:
|
| 327 |
+
return hidden_states
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
+class Flux2KVParallelSelfAttnProcessor:
|
| 331 |
+
+ """
|
| 332 |
+
+ Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
|
| 333 |
+
+
|
| 334 |
+
+ When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used.
|
| 335 |
+
+ When `kv_cache_mode` is "cached", cached ref K/V are injected during attention.
|
| 336 |
+
+ When no KV args are provided, behaves identically to `Flux2ParallelSelfAttnProcessor`.
|
| 337 |
+
+ """
|
| 338 |
+
+
|
| 339 |
+
+ _attention_backend = None
|
| 340 |
+
+ _parallel_config = None
|
| 341 |
+
+
|
| 342 |
+
+ def __init__(self):
|
| 343 |
+
+ if not hasattr(F, "scaled_dot_product_attention"):
|
| 344 |
+
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
| 345 |
+
+
|
| 346 |
+
+ def __call__(
|
| 347 |
+
+ self,
|
| 348 |
+
+ attn: "Flux2ParallelSelfAttention",
|
| 349 |
+
+ hidden_states: torch.Tensor,
|
| 350 |
+
+ attention_mask: torch.Tensor | None = None,
|
| 351 |
+
+ image_rotary_emb: torch.Tensor | None = None,
|
| 352 |
+
+ kv_cache: Flux2KVLayerCache | None = None,
|
| 353 |
+
+ kv_cache_mode: str | None = None,
|
| 354 |
+
+ num_txt_tokens: int = 0,
|
| 355 |
+
+ num_ref_tokens: int = 0,
|
| 356 |
+
+ ) -> torch.Tensor:
|
| 357 |
+
+ # Parallel in (QKV + MLP in) projection
|
| 358 |
+
+ hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
|
| 359 |
+
+ qkv, mlp_hidden_states = torch.split(
|
| 360 |
+
+ hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
| 361 |
+
+ )
|
| 362 |
+
+
|
| 363 |
+
+ query, key, value = qkv.chunk(3, dim=-1)
|
| 364 |
+
+
|
| 365 |
+
+ query = query.unflatten(-1, (attn.heads, -1))
|
| 366 |
+
+ key = key.unflatten(-1, (attn.heads, -1))
|
| 367 |
+
+ value = value.unflatten(-1, (attn.heads, -1))
|
| 368 |
+
+
|
| 369 |
+
+ query = attn.norm_q(query)
|
| 370 |
+
+ key = attn.norm_k(key)
|
| 371 |
+
+
|
| 372 |
+
+ if image_rotary_emb is not None:
|
| 373 |
+
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 374 |
+
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 375 |
+
+
|
| 376 |
+
+ # Extract ref K/V from the combined sequence
|
| 377 |
+
+ if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
| 378 |
+
+ ref_start = num_txt_tokens
|
| 379 |
+
+ ref_end = num_txt_tokens + num_ref_tokens
|
| 380 |
+
+ kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
| 381 |
+
+
|
| 382 |
+
+ # Dispatch attention
|
| 383 |
+
+ if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
| 384 |
+
+ attn_output = _flux2_kv_causal_attention(
|
| 385 |
+
+ query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
| 386 |
+
+ )
|
| 387 |
+
+ elif kv_cache_mode == "cached" and kv_cache is not None:
|
| 388 |
+
+ attn_output = _flux2_kv_causal_attention(
|
| 389 |
+
+ query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
| 390 |
+
+ )
|
| 391 |
+
+ else:
|
| 392 |
+
+ attn_output = dispatch_attention_fn(
|
| 393 |
+
+ query, key, value, attn_mask=attention_mask,
|
| 394 |
+
+ backend=self._attention_backend, parallel_config=self._parallel_config,
|
| 395 |
+
+ )
|
| 396 |
+
+
|
| 397 |
+
+ attn_output = attn_output.flatten(2, 3)
|
| 398 |
+
+ attn_output = attn_output.to(query.dtype)
|
| 399 |
+
+
|
| 400 |
+
+ # Handle the feedforward (FF) logic
|
| 401 |
+
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
| 402 |
+
+
|
| 403 |
+
+ # Concatenate and parallel output projection
|
| 404 |
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
|
| 405 |
+
+ hidden_states = attn.to_out(hidden_states)
|
| 406 |
+
+
|
| 407 |
+
+ return hidden_states
|
| 408 |
+
+
|
| 409 |
+
+
|
| 410 |
+
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
| 411 |
+
"""
|
| 412 |
+
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
| 413 |
+
@@ -322,7 +685,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
| 414 |
+
"""
|
| 415 |
+
|
| 416 |
+
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
| 417 |
+
- _available_processors = [Flux2ParallelSelfAttnProcessor]
|
| 418 |
+
+ _available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
|
| 419 |
+
# Does not support QKV fusion as the QKV projections are always fused
|
| 420 |
+
_supports_qkv_fusion = False
|
| 421 |
+
|
| 422 |
+
@@ -780,6 +1143,8 @@ class Flux2Transformer2DModel(
|
| 423 |
+
|
| 424 |
+
self.gradient_checkpointing = False
|
| 425 |
+
|
| 426 |
+
+ _skip_keys = ["kv_cache"]
|
| 427 |
+
+
|
| 428 |
+
@apply_lora_scale("joint_attention_kwargs")
|
| 429 |
+
def forward(
|
| 430 |
+
self,
|
| 431 |
+
@@ -791,19 +1156,21 @@ class Flux2Transformer2DModel(
|
| 432 |
+
guidance: torch.Tensor = None,
|
| 433 |
+
joint_attention_kwargs: dict[str, Any] | None = None,
|
| 434 |
+
return_dict: bool = True,
|
| 435 |
+
+ kv_cache: "Flux2KVCache | None" = None,
|
| 436 |
+
+ kv_cache_mode: str | None = None,
|
| 437 |
+
+ num_ref_tokens: int = 0,
|
| 438 |
+
+ ref_fixed_timestep: float = 0.0,
|
| 439 |
+
) -> torch.Tensor | Transformer2DModelOutput:
|
| 440 |
+
"""
|
| 441 |
+
- The [`FluxTransformer2DModel`] forward method.
|
| 442 |
+
+ The [`Flux2Transformer2DModel`] forward method.
|
| 443 |
+
|
| 444 |
+
Args:
|
| 445 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 446 |
+
Input `hidden_states`.
|
| 447 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 448 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 449 |
+
- timestep ( `torch.LongTensor`):
|
| 450 |
+
+ timestep (`torch.LongTensor`):
|
| 451 |
+
Used to indicate denoising step.
|
| 452 |
+
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 453 |
+
- A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 454 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 455 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 456 |
+
`self.processor` in
|
| 457 |
+
@@ -811,13 +1178,23 @@ class Flux2Transformer2DModel(
|
| 458 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 459 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 460 |
+
tuple.
|
| 461 |
+
+ kv_cache (`Flux2KVCache`, *optional*):
|
| 462 |
+
+ KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created
|
| 463 |
+
+ and returned. When "cached", the provided cache is used to inject ref K/V during attention.
|
| 464 |
+
+ kv_cache_mode (`str`, *optional*):
|
| 465 |
+
+ One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V).
|
| 466 |
+
+ When `None`, standard forward pass without KV caching.
|
| 467 |
+
+ num_ref_tokens (`int`, defaults to `0`):
|
| 468 |
+
+ Number of reference image tokens prepended to `hidden_states` (only used when
|
| 469 |
+
+ `kv_cache_mode="extract"`).
|
| 470 |
+
+ ref_fixed_timestep (`float`, defaults to `0.0`):
|
| 471 |
+
+ Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 475 |
+
- `tuple` where the first element is the sample tensor.
|
| 476 |
+
+ `tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
|
| 477 |
+
+ populated `Flux2KVCache`.
|
| 478 |
+
"""
|
| 479 |
+
- # 0. Handle input arguments
|
| 480 |
+
-
|
| 481 |
+
num_txt_tokens = encoder_hidden_states.shape[1]
|
| 482 |
+
|
| 483 |
+
# 1. Calculate timestep embedding and modulation parameters
|
| 484 |
+
@@ -832,13 +1209,33 @@ class Flux2Transformer2DModel(
|
| 485 |
+
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
| 486 |
+
single_stream_mod = self.single_stream_modulation(temb)
|
| 487 |
+
|
| 488 |
+
+ # KV extract mode: create cache and blend modulations for ref tokens
|
| 489 |
+
+ if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
| 490 |
+
+ num_img_tokens = hidden_states.shape[1] # includes ref tokens
|
| 491 |
+
+
|
| 492 |
+
+ kv_cache = Flux2KVCache(
|
| 493 |
+
+ num_double_layers=len(self.transformer_blocks),
|
| 494 |
+
+ num_single_layers=len(self.single_transformer_blocks),
|
| 495 |
+
+ )
|
| 496 |
+
+ kv_cache.num_ref_tokens = num_ref_tokens
|
| 497 |
+
+
|
| 498 |
+
+ # Ref tokens use a fixed timestep for modulation
|
| 499 |
+
+ ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
|
| 500 |
+
+ ref_temb = self.time_guidance_embed(ref_timestep, guidance)
|
| 501 |
+
+
|
| 502 |
+
+ ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
|
| 503 |
+
+ ref_single_mod = self.single_stream_modulation(ref_temb)
|
| 504 |
+
+
|
| 505 |
+
+ # Blend double block img modulation: [ref_mod, img_mod]
|
| 506 |
+
+ double_stream_mod_img = _blend_double_block_mods(
|
| 507 |
+
+ double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
|
| 508 |
+
+ )
|
| 509 |
+
+
|
| 510 |
+
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
| 511 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 512 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 513 |
+
|
| 514 |
+
# 3. Calculate RoPE embeddings from image and text tokens
|
| 515 |
+
- # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
| 516 |
+
- # text prompts of differents lengths. Is this a use case we want to support?
|
| 517 |
+
if img_ids.ndim == 3:
|
| 518 |
+
img_ids = img_ids[0]
|
| 519 |
+
if txt_ids.ndim == 3:
|
| 520 |
+
@@ -851,8 +1248,29 @@ class Flux2Transformer2DModel(
|
| 521 |
+
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
- # 4. Double Stream Transformer Blocks
|
| 525 |
+
+ # 4. Build joint_attention_kwargs with KV cache info
|
| 526 |
+
+ if kv_cache_mode == "extract":
|
| 527 |
+
+ kv_attn_kwargs = {
|
| 528 |
+
+ **(joint_attention_kwargs or {}),
|
| 529 |
+
+ "kv_cache": None,
|
| 530 |
+
+ "kv_cache_mode": "extract",
|
| 531 |
+
+ "num_ref_tokens": num_ref_tokens,
|
| 532 |
+
+ }
|
| 533 |
+
+ elif kv_cache_mode == "cached" and kv_cache is not None:
|
| 534 |
+
+ kv_attn_kwargs = {
|
| 535 |
+
+ **(joint_attention_kwargs or {}),
|
| 536 |
+
+ "kv_cache": None,
|
| 537 |
+
+ "kv_cache_mode": "cached",
|
| 538 |
+
+ "num_ref_tokens": kv_cache.num_ref_tokens,
|
| 539 |
+
+ }
|
| 540 |
+
+ else:
|
| 541 |
+
+ kv_attn_kwargs = joint_attention_kwargs
|
| 542 |
+
+
|
| 543 |
+
+ # 5. Double Stream Transformer Blocks
|
| 544 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 545 |
+
+ if kv_cache_mode is not None and kv_cache is not None:
|
| 546 |
+
+ kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
|
| 547 |
+
+
|
| 548 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 549 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 550 |
+
block,
|
| 551 |
+
@@ -861,7 +1279,7 @@ class Flux2Transformer2DModel(
|
| 552 |
+
double_stream_mod_img,
|
| 553 |
+
double_stream_mod_txt,
|
| 554 |
+
concat_rotary_emb,
|
| 555 |
+
- joint_attention_kwargs,
|
| 556 |
+
+ kv_attn_kwargs,
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
encoder_hidden_states, hidden_states = block(
|
| 560 |
+
@@ -870,13 +1288,30 @@ class Flux2Transformer2DModel(
|
| 561 |
+
temb_mod_img=double_stream_mod_img,
|
| 562 |
+
temb_mod_txt=double_stream_mod_txt,
|
| 563 |
+
image_rotary_emb=concat_rotary_emb,
|
| 564 |
+
- joint_attention_kwargs=joint_attention_kwargs,
|
| 565 |
+
+ joint_attention_kwargs=kv_attn_kwargs,
|
| 566 |
+
)
|
| 567 |
+
+
|
| 568 |
+
# Concatenate text and image streams for single-block inference
|
| 569 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 570 |
+
|
| 571 |
+
- # 5. Single Stream Transformer Blocks
|
| 572 |
+
+ # Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
|
| 573 |
+
+ if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
| 574 |
+
+ total_single_len = hidden_states.shape[1]
|
| 575 |
+
+ single_stream_mod = _blend_single_block_mods(
|
| 576 |
+
+ single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
|
| 577 |
+
+ )
|
| 578 |
+
+
|
| 579 |
+
+ # Build single-block KV kwargs (single blocks need num_txt_tokens)
|
| 580 |
+
+ if kv_cache_mode is not None:
|
| 581 |
+
+ kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
|
| 582 |
+
+ else:
|
| 583 |
+
+ kv_attn_kwargs_single = kv_attn_kwargs
|
| 584 |
+
+
|
| 585 |
+
+ # 6. Single Stream Transformer Blocks
|
| 586 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 587 |
+
+ if kv_cache_mode is not None and kv_cache is not None:
|
| 588 |
+
+ kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
|
| 589 |
+
+
|
| 590 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 591 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 592 |
+
block,
|
| 593 |
+
@@ -884,7 +1319,7 @@ class Flux2Transformer2DModel(
|
| 594 |
+
None,
|
| 595 |
+
single_stream_mod,
|
| 596 |
+
concat_rotary_emb,
|
| 597 |
+
- joint_attention_kwargs,
|
| 598 |
+
+ kv_attn_kwargs_single,
|
| 599 |
+
)
|
| 600 |
+
else:
|
| 601 |
+
hidden_states = block(
|
| 602 |
+
@@ -892,15 +1327,24 @@ class Flux2Transformer2DModel(
|
| 603 |
+
encoder_hidden_states=None,
|
| 604 |
+
temb_mod=single_stream_mod,
|
| 605 |
+
image_rotary_emb=concat_rotary_emb,
|
| 606 |
+
- joint_attention_kwargs=joint_attention_kwargs,
|
| 607 |
+
+ joint_attention_kwargs=kv_attn_kwargs_single,
|
| 608 |
+
)
|
| 609 |
+
- # Remove text tokens from concatenated stream
|
| 610 |
+
- hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
| 611 |
+
|
| 612 |
+
- # 6. Output layers
|
| 613 |
+
+ # Remove text tokens (and ref tokens in extract mode) from concatenated stream
|
| 614 |
+
+ if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
| 615 |
+
+ hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
|
| 616 |
+
+ else:
|
| 617 |
+
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
| 618 |
+
+
|
| 619 |
+
+ # 7. Output layers
|
| 620 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 621 |
+
output = self.proj_out(hidden_states)
|
| 622 |
+
|
| 623 |
+
+ if kv_cache_mode == "extract":
|
| 624 |
+
+ if not return_dict:
|
| 625 |
+
+ return (output,), kv_cache
|
| 626 |
+
+ return Transformer2DModelOutput(sample=output), kv_cache
|
| 627 |
+
+
|
| 628 |
+
if not return_dict:
|
| 629 |
+
return (output,)
|
| 630 |
+
|
| 631 |
+
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
|
| 632 |
+
index 800703533..b9596f4b7 100644
|
| 633 |
+
--- a/src/diffusers/pipelines/__init__.py
|
| 634 |
+
+++ b/src/diffusers/pipelines/__init__.py
|
| 635 |
+
@@ -129,7 +129,7 @@ else:
|
| 636 |
+
]
|
| 637 |
+
_import_structure["bria"] = ["BriaPipeline"]
|
| 638 |
+
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
|
| 639 |
+
- _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
| 640 |
+
+ _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
|
| 641 |
+
_import_structure["flux"] = [
|
| 642 |
+
"FluxControlPipeline",
|
| 643 |
+
"FluxControlInpaintPipeline",
|
| 644 |
+
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 645 |
+
FluxPriorReduxPipeline,
|
| 646 |
+
ReduxImageEncoder,
|
| 647 |
+
)
|
| 648 |
+
- from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
| 649 |
+
+ from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
|
| 650 |
+
from .glm_image import GlmImagePipeline
|
| 651 |
+
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
| 652 |
+
from .hidream_image import HiDreamImagePipeline
|
| 653 |
+
diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py
|
| 654 |
+
index f6e1d5206..52a8f464b 100644
|
| 655 |
+
--- a/src/diffusers/pipelines/flux2/__init__.py
|
| 656 |
+
+++ b/src/diffusers/pipelines/flux2/__init__.py
|
| 657 |
+
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
| 658 |
+
else:
|
| 659 |
+
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
| 660 |
+
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
| 661 |
+
+ _import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
|
| 662 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 663 |
+
try:
|
| 664 |
+
if not (is_transformers_available() and is_torch_available()):
|
| 665 |
+
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
| 666 |
+
else:
|
| 667 |
+
from .pipeline_flux2 import Flux2Pipeline
|
| 668 |
+
from .pipeline_flux2_klein import Flux2KleinPipeline
|
| 669 |
+
+ from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
| 670 |
+
else:
|
| 671 |
+
import sys
|
| 672 |
+
|
| 673 |
+
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
|
| 674 |
+
new file mode 100644
|
| 675 |
+
index 000000000..cac29f621
|
| 676 |
+
--- /dev/null
|
| 677 |
+
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
|
| 678 |
+
@@ -0,0 +1,887 @@
|
| 679 |
+
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 680 |
+
+#
|
| 681 |
+
+# Licensed under the Apache License, Version 2.0 (the "License");
|
| 682 |
+
+# you may not use this file except in compliance with the License.
|
| 683 |
+
+# You may obtain a copy of the License at
|
| 684 |
+
+#
|
| 685 |
+
+# http://www.apache.org/licenses/LICENSE-2.0
|
| 686 |
+
+#
|
| 687 |
+
+# Unless required by applicable law or agreed to in writing, software
|
| 688 |
+
+# distributed under the License is distributed on an "AS IS" BASIS,
|
| 689 |
+
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 690 |
+
+# See the License for the specific language governing permissions and
|
| 691 |
+
+# limitations under the License.
|
| 692 |
+
+
|
| 693 |
+
+import inspect
|
| 694 |
+
+from typing import Any, Callable
|
| 695 |
+
+
|
| 696 |
+
+import numpy as np
|
| 697 |
+
+import PIL
|
| 698 |
+
+import torch
|
| 699 |
+
+from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
| 700 |
+
+
|
| 701 |
+
+from ...loaders import Flux2LoraLoaderMixin
|
| 702 |
+
+from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
| 703 |
+
+from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
|
| 704 |
+
+from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 705 |
+
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
| 706 |
+
+from ...utils.torch_utils import randn_tensor
|
| 707 |
+
+from ..pipeline_utils import DiffusionPipeline
|
| 708 |
+
+from .image_processor import Flux2ImageProcessor
|
| 709 |
+
+from .pipeline_output import Flux2PipelineOutput
|
| 710 |
+
+
|
| 711 |
+
+
|
| 712 |
+
+if is_torch_xla_available():
|
| 713 |
+
+ import torch_xla.core.xla_model as xm
|
| 714 |
+
+
|
| 715 |
+
+ XLA_AVAILABLE = True
|
| 716 |
+
+else:
|
| 717 |
+
+ XLA_AVAILABLE = False
|
| 718 |
+
+
|
| 719 |
+
+
|
| 720 |
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 721 |
+
+
|
| 722 |
+
+EXAMPLE_DOC_STRING = """
|
| 723 |
+
+ Examples:
|
| 724 |
+
+ ```py
|
| 725 |
+
+ >>> import torch
|
| 726 |
+
+ >>> from PIL import Image
|
| 727 |
+
+ >>> from diffusers import Flux2KleinKVPipeline
|
| 728 |
+
+
|
| 729 |
+
+ >>> pipe = Flux2KleinKVPipeline.from_pretrained(
|
| 730 |
+
+ ... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
|
| 731 |
+
+ ... )
|
| 732 |
+
+ >>> pipe.to("cuda")
|
| 733 |
+
+ >>> ref_image = Image.open("reference.png")
|
| 734 |
+
+ >>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
|
| 735 |
+
+ >>> image.save("flux2_kv_output.png")
|
| 736 |
+
+ ```
|
| 737 |
+
+"""
|
| 738 |
+
+
|
| 739 |
+
+
|
| 740 |
+
+# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
| 741 |
+
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
| 742 |
+
+ a1, b1 = 8.73809524e-05, 1.89833333
|
| 743 |
+
+ a2, b2 = 0.00016927, 0.45666666
|
| 744 |
+
+
|
| 745 |
+
+ if image_seq_len > 4300:
|
| 746 |
+
+ mu = a2 * image_seq_len + b2
|
| 747 |
+
+ return float(mu)
|
| 748 |
+
+
|
| 749 |
+
+ m_200 = a2 * image_seq_len + b2
|
| 750 |
+
+ m_10 = a1 * image_seq_len + b1
|
| 751 |
+
+
|
| 752 |
+
+ a = (m_200 - m_10) / 190.0
|
| 753 |
+
+ b = m_200 - 200.0 * a
|
| 754 |
+
+ mu = a * num_steps + b
|
| 755 |
+
+
|
| 756 |
+
+ return float(mu)
|
| 757 |
+
+
|
| 758 |
+
+
|
| 759 |
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 760 |
+
+def retrieve_timesteps(
|
| 761 |
+
+ scheduler,
|
| 762 |
+
+ num_inference_steps: int | None = None,
|
| 763 |
+
+ device: str | torch.device | None = None,
|
| 764 |
+
+ timesteps: list[int] | None = None,
|
| 765 |
+
+ sigmas: list[float] | None = None,
|
| 766 |
+
+ **kwargs,
|
| 767 |
+
+):
|
| 768 |
+
+ r"""
|
| 769 |
+
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 770 |
+
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 771 |
+
+
|
| 772 |
+
+ Args:
|
| 773 |
+
+ scheduler (`SchedulerMixin`):
|
| 774 |
+
+ The scheduler to get timesteps from.
|
| 775 |
+
+ num_inference_steps (`int`):
|
| 776 |
+
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 777 |
+
+ must be `None`.
|
| 778 |
+
+ device (`str` or `torch.device`, *optional*):
|
| 779 |
+
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 780 |
+
+ timesteps (`list[int]`, *optional*):
|
| 781 |
+
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 782 |
+
+ `num_inference_steps` and `sigmas` must be `None`.
|
| 783 |
+
+ sigmas (`list[float]`, *optional*):
|
| 784 |
+
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 785 |
+
+ `num_inference_steps` and `timesteps` must be `None`.
|
| 786 |
+
+
|
| 787 |
+
+ Returns:
|
| 788 |
+
+ `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 789 |
+
+ second element is the number of inference steps.
|
| 790 |
+
+ """
|
| 791 |
+
+ if timesteps is not None and sigmas is not None:
|
| 792 |
+
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 793 |
+
+ if timesteps is not None:
|
| 794 |
+
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 795 |
+
+ if not accepts_timesteps:
|
| 796 |
+
+ raise ValueError(
|
| 797 |
+
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 798 |
+
+ f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 799 |
+
+ )
|
| 800 |
+
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 801 |
+
+ timesteps = scheduler.timesteps
|
| 802 |
+
+ num_inference_steps = len(timesteps)
|
| 803 |
+
+ elif sigmas is not None:
|
| 804 |
+
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 805 |
+
+ if not accept_sigmas:
|
| 806 |
+
+ raise ValueError(
|
| 807 |
+
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 808 |
+
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 809 |
+
+ )
|
| 810 |
+
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 811 |
+
+ timesteps = scheduler.timesteps
|
| 812 |
+
+ num_inference_steps = len(timesteps)
|
| 813 |
+
+ else:
|
| 814 |
+
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 815 |
+
+ timesteps = scheduler.timesteps
|
| 816 |
+
+ return timesteps, num_inference_steps
|
| 817 |
+
+
|
| 818 |
+
+
|
| 819 |
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 820 |
+
+def retrieve_latents(
|
| 821 |
+
+ encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
| 822 |
+
+):
|
| 823 |
+
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 824 |
+
+ return encoder_output.latent_dist.sample(generator)
|
| 825 |
+
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 826 |
+
+ return encoder_output.latent_dist.mode()
|
| 827 |
+
+ elif hasattr(encoder_output, "latents"):
|
| 828 |
+
+ return encoder_output.latents
|
| 829 |
+
+ else:
|
| 830 |
+
+ raise AttributeError("Could not access latents of provided encoder_output")
|
| 831 |
+
+
|
| 832 |
+
+
|
| 833 |
+
+class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
| 834 |
+
+ r"""
|
| 835 |
+
+ The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
|
| 836 |
+
+
|
| 837 |
+
+ On the first denoising step, reference image tokens are included in the forward pass and their attention
|
| 838 |
+
+ K/V projections are cached. On subsequent steps, the cached K/V are reused without recomputing,
|
| 839 |
+
+ providing faster inference when using reference images.
|
| 840 |
+
+
|
| 841 |
+
+ Reference:
|
| 842 |
+
+ [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
| 843 |
+
+
|
| 844 |
+
+ Args:
|
| 845 |
+
+ transformer ([`Flux2Transformer2DModel`]):
|
| 846 |
+
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 847 |
+
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 848 |
+
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 849 |
+
+ vae ([`AutoencoderKLFlux2`]):
|
| 850 |
+
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 851 |
+
+ text_encoder ([`Qwen3ForCausalLM`]):
|
| 852 |
+
+ [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
| 853 |
+
+ tokenizer (`Qwen2TokenizerFast`):
|
| 854 |
+
+ Tokenizer of class
|
| 855 |
+
+ [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
| 856 |
+
+ """
|
| 857 |
+
+
|
| 858 |
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 859 |
+
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 860 |
+
+
|
| 861 |
+
+ def __init__(
|
| 862 |
+
+ self,
|
| 863 |
+
+ scheduler: FlowMatchEulerDiscreteScheduler,
|
| 864 |
+
+ vae: AutoencoderKLFlux2,
|
| 865 |
+
+ text_encoder: Qwen3ForCausalLM,
|
| 866 |
+
+ tokenizer: Qwen2TokenizerFast,
|
| 867 |
+
+ transformer: Flux2Transformer2DModel,
|
| 868 |
+
+ is_distilled: bool = True,
|
| 869 |
+
+ ):
|
| 870 |
+
+ super().__init__()
|
| 871 |
+
+
|
| 872 |
+
+ self.register_modules(
|
| 873 |
+
+ vae=vae,
|
| 874 |
+
+ text_encoder=text_encoder,
|
| 875 |
+
+ tokenizer=tokenizer,
|
| 876 |
+
+ scheduler=scheduler,
|
| 877 |
+
+ transformer=transformer,
|
| 878 |
+
+ )
|
| 879 |
+
+
|
| 880 |
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 881 |
+
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 882 |
+
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 883 |
+
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 884 |
+
+ self.tokenizer_max_length = 512
|
| 885 |
+
+ self.default_sample_size = 128
|
| 886 |
+
+
|
| 887 |
+
+ # Set KV-cache-aware attention processors
|
| 888 |
+
+ self._set_kv_attn_processors()
|
| 889 |
+
+
|
| 890 |
+
+ @staticmethod
|
| 891 |
+
+ def _get_qwen3_prompt_embeds(
|
| 892 |
+
+ text_encoder: Qwen3ForCausalLM,
|
| 893 |
+
+ tokenizer: Qwen2TokenizerFast,
|
| 894 |
+
+ prompt: str | list[str],
|
| 895 |
+
+ dtype: torch.dtype | None = None,
|
| 896 |
+
+ device: torch.device | None = None,
|
| 897 |
+
+ max_sequence_length: int = 512,
|
| 898 |
+
+ hidden_states_layers: list[int] = (9, 18, 27),
|
| 899 |
+
+ ):
|
| 900 |
+
+ dtype = text_encoder.dtype if dtype is None else dtype
|
| 901 |
+
+ device = text_encoder.device if device is None else device
|
| 902 |
+
+
|
| 903 |
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 904 |
+
+
|
| 905 |
+
+ all_input_ids = []
|
| 906 |
+
+ all_attention_masks = []
|
| 907 |
+
+
|
| 908 |
+
+ for single_prompt in prompt:
|
| 909 |
+
+ messages = [{"role": "user", "content": single_prompt}]
|
| 910 |
+
+ text = tokenizer.apply_chat_template(
|
| 911 |
+
+ messages,
|
| 912 |
+
+ tokenize=False,
|
| 913 |
+
+ add_generation_prompt=True,
|
| 914 |
+
+ enable_thinking=False,
|
| 915 |
+
+ )
|
| 916 |
+
+ inputs = tokenizer(
|
| 917 |
+
+ text,
|
| 918 |
+
+ return_tensors="pt",
|
| 919 |
+
+ padding="max_length",
|
| 920 |
+
+ truncation=True,
|
| 921 |
+
+ max_length=max_sequence_length,
|
| 922 |
+
+ )
|
| 923 |
+
+
|
| 924 |
+
+ all_input_ids.append(inputs["input_ids"])
|
| 925 |
+
+ all_attention_masks.append(inputs["attention_mask"])
|
| 926 |
+
+
|
| 927 |
+
+ input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
| 928 |
+
+ attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
| 929 |
+
+
|
| 930 |
+
+ # Forward pass through the model
|
| 931 |
+
+ output = text_encoder(
|
| 932 |
+
+ input_ids=input_ids,
|
| 933 |
+
+ attention_mask=attention_mask,
|
| 934 |
+
+ output_hidden_states=True,
|
| 935 |
+
+ use_cache=False,
|
| 936 |
+
+ )
|
| 937 |
+
+
|
| 938 |
+
+ # Only use outputs from intermediate layers and stack them
|
| 939 |
+
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
| 940 |
+
+ out = out.to(dtype=dtype, device=device)
|
| 941 |
+
+
|
| 942 |
+
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
|
| 943 |
+
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
| 944 |
+
+
|
| 945 |
+
+ return prompt_embeds
|
| 946 |
+
+
|
| 947 |
+
+ @staticmethod
|
| 948 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
| 949 |
+
+ def _prepare_text_ids(
|
| 950 |
+
+ x: torch.Tensor, # (B, L, D) or (L, D)
|
| 951 |
+
+ t_coord: torch.Tensor | None = None,
|
| 952 |
+
+ ):
|
| 953 |
+
+ B, L, _ = x.shape
|
| 954 |
+
+ out_ids = []
|
| 955 |
+
+
|
| 956 |
+
+ for i in range(B):
|
| 957 |
+
+ t = torch.arange(1) if t_coord is None else t_coord[i]
|
| 958 |
+
+ h = torch.arange(1)
|
| 959 |
+
+ w = torch.arange(1)
|
| 960 |
+
+ l = torch.arange(L)
|
| 961 |
+
+
|
| 962 |
+
+ coords = torch.cartesian_prod(t, h, w, l)
|
| 963 |
+
+ out_ids.append(coords)
|
| 964 |
+
+
|
| 965 |
+
+ return torch.stack(out_ids)
|
| 966 |
+
+
|
| 967 |
+
+ @staticmethod
|
| 968 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
| 969 |
+
+ def _prepare_latent_ids(
|
| 970 |
+
+ latents: torch.Tensor, # (B, C, H, W)
|
| 971 |
+
+ ):
|
| 972 |
+
+ r"""
|
| 973 |
+
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
| 974 |
+
+
|
| 975 |
+
+ Args:
|
| 976 |
+
+ latents (torch.Tensor):
|
| 977 |
+
+ Latent tensor of shape (B, C, H, W)
|
| 978 |
+
+
|
| 979 |
+
+ Returns:
|
| 980 |
+
+ torch.Tensor:
|
| 981 |
+
+ Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
| 982 |
+
+ H=[0..H-1], W=[0..W-1], L=0
|
| 983 |
+
+ """
|
| 984 |
+
+
|
| 985 |
+
+ batch_size, _, height, width = latents.shape
|
| 986 |
+
+
|
| 987 |
+
+ t = torch.arange(1) # [0] - time dimension
|
| 988 |
+
+ h = torch.arange(height)
|
| 989 |
+
+ w = torch.arange(width)
|
| 990 |
+
+ l = torch.arange(1) # [0] - layer dimension
|
| 991 |
+
+
|
| 992 |
+
+ # Create position IDs: (H*W, 4)
|
| 993 |
+
+ latent_ids = torch.cartesian_prod(t, h, w, l)
|
| 994 |
+
+
|
| 995 |
+
+ # Expand to batch: (B, H*W, 4)
|
| 996 |
+
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
| 997 |
+
+
|
| 998 |
+
+ return latent_ids
|
| 999 |
+
+
|
| 1000 |
+
+ @staticmethod
|
| 1001 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
| 1002 |
+
+ def _prepare_image_ids(
|
| 1003 |
+
+ image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
| 1004 |
+
+ scale: int = 10,
|
| 1005 |
+
+ ):
|
| 1006 |
+
+ r"""
|
| 1007 |
+
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
| 1008 |
+
+
|
| 1009 |
+
+ This function creates a unique coordinate for every pixel/patch across all input latent with different
|
| 1010 |
+
+ dimensions.
|
| 1011 |
+
+
|
| 1012 |
+
+ Args:
|
| 1013 |
+
+ image_latents (list[torch.Tensor]):
|
| 1014 |
+
+ A list of image latent feature tensors, typically of shape (C, H, W).
|
| 1015 |
+
+ scale (int, optional):
|
| 1016 |
+
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
| 1017 |
+
+ latent is: 'scale + scale * i'. Defaults to 10.
|
| 1018 |
+
+
|
| 1019 |
+
+ Returns:
|
| 1020 |
+
+ torch.Tensor:
|
| 1021 |
+
+ The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
| 1022 |
+
+ input latents.
|
| 1023 |
+
+
|
| 1024 |
+
+ Coordinate Components (Dimension 4):
|
| 1025 |
+
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
|
| 1026 |
+
+ - H (Height): The row index within that latent image.
|
| 1027 |
+
+ - W (Width): The column index within that latent image.
|
| 1028 |
+
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
| 1029 |
+
+ """
|
| 1030 |
+
+
|
| 1031 |
+
+ if not isinstance(image_latents, list):
|
| 1032 |
+
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
| 1033 |
+
+
|
| 1034 |
+
+ # create time offset for each reference image
|
| 1035 |
+
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
| 1036 |
+
+ t_coords = [t.view(-1) for t in t_coords]
|
| 1037 |
+
+
|
| 1038 |
+
+ image_latent_ids = []
|
| 1039 |
+
+ for x, t in zip(image_latents, t_coords):
|
| 1040 |
+
+ x = x.squeeze(0)
|
| 1041 |
+
+ _, height, width = x.shape
|
| 1042 |
+
+
|
| 1043 |
+
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
| 1044 |
+
+ image_latent_ids.append(x_ids)
|
| 1045 |
+
+
|
| 1046 |
+
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
| 1047 |
+
+ image_latent_ids = image_latent_ids.unsqueeze(0)
|
| 1048 |
+
+
|
| 1049 |
+
+ return image_latent_ids
|
| 1050 |
+
+
|
| 1051 |
+
+ @staticmethod
|
| 1052 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
| 1053 |
+
+ def _patchify_latents(latents):
|
| 1054 |
+
+ batch_size, num_channels_latents, height, width = latents.shape
|
| 1055 |
+
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 1056 |
+
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
|
| 1057 |
+
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
| 1058 |
+
+ return latents
|
| 1059 |
+
+
|
| 1060 |
+
+ @staticmethod
|
| 1061 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
| 1062 |
+
+ def _unpatchify_latents(latents):
|
| 1063 |
+
+ batch_size, num_channels_latents, height, width = latents.shape
|
| 1064 |
+
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
| 1065 |
+
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
|
| 1066 |
+
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
| 1067 |
+
+ return latents
|
| 1068 |
+
+
|
| 1069 |
+
+ @staticmethod
|
| 1070 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
| 1071 |
+
+ def _pack_latents(latents):
|
| 1072 |
+
+ """
|
| 1073 |
+
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
| 1074 |
+
+ """
|
| 1075 |
+
+
|
| 1076 |
+
+ batch_size, num_channels, height, width = latents.shape
|
| 1077 |
+
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
| 1078 |
+
+
|
| 1079 |
+
+ return latents
|
| 1080 |
+
+
|
| 1081 |
+
+ @staticmethod
|
| 1082 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
| 1083 |
+
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
| 1084 |
+
+ """
|
| 1085 |
+
+ using position ids to scatter tokens into place
|
| 1086 |
+
+ """
|
| 1087 |
+
+ x_list = []
|
| 1088 |
+
+ for data, pos in zip(x, x_ids):
|
| 1089 |
+
+ _, ch = data.shape # noqa: F841
|
| 1090 |
+
+ h_ids = pos[:, 1].to(torch.int64)
|
| 1091 |
+
+ w_ids = pos[:, 2].to(torch.int64)
|
| 1092 |
+
+
|
| 1093 |
+
+ h = torch.max(h_ids) + 1
|
| 1094 |
+
+ w = torch.max(w_ids) + 1
|
| 1095 |
+
+
|
| 1096 |
+
+ flat_ids = h_ids * w + w_ids
|
| 1097 |
+
+
|
| 1098 |
+
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
| 1099 |
+
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
| 1100 |
+
+
|
| 1101 |
+
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
| 1102 |
+
+
|
| 1103 |
+
+ out = out.view(h, w, ch).permute(2, 0, 1)
|
| 1104 |
+
+ x_list.append(out)
|
| 1105 |
+
+
|
| 1106 |
+
+ return torch.stack(x_list, dim=0)
|
| 1107 |
+
+
|
| 1108 |
+
+ def _set_kv_attn_processors(self):
|
| 1109 |
+
+ """Replace default attention processors with KV-cache-aware variants."""
|
| 1110 |
+
+ for block in self.transformer.transformer_blocks:
|
| 1111 |
+
+ block.attn.set_processor(Flux2KVAttnProcessor())
|
| 1112 |
+
+ for block in self.transformer.single_transformer_blocks:
|
| 1113 |
+
+ block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
| 1114 |
+
+
|
| 1115 |
+
+ def encode_prompt(
|
| 1116 |
+
+ self,
|
| 1117 |
+
+ prompt: str | list[str],
|
| 1118 |
+
+ device: torch.device | None = None,
|
| 1119 |
+
+ num_images_per_prompt: int = 1,
|
| 1120 |
+
+ prompt_embeds: torch.Tensor | None = None,
|
| 1121 |
+
+ max_sequence_length: int = 512,
|
| 1122 |
+
+ text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
| 1123 |
+
+ ):
|
| 1124 |
+
+ device = device or self._execution_device
|
| 1125 |
+
+
|
| 1126 |
+
+ if prompt is None:
|
| 1127 |
+
+ prompt = ""
|
| 1128 |
+
+
|
| 1129 |
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 1130 |
+
+
|
| 1131 |
+
+ if prompt_embeds is None:
|
| 1132 |
+
+ prompt_embeds = self._get_qwen3_prompt_embeds(
|
| 1133 |
+
+ text_encoder=self.text_encoder,
|
| 1134 |
+
+ tokenizer=self.tokenizer,
|
| 1135 |
+
+ prompt=prompt,
|
| 1136 |
+
+ device=device,
|
| 1137 |
+
+ max_sequence_length=max_sequence_length,
|
| 1138 |
+
+ hidden_states_layers=text_encoder_out_layers,
|
| 1139 |
+
+ )
|
| 1140 |
+
+
|
| 1141 |
+
+ batch_size, seq_len, _ = prompt_embeds.shape
|
| 1142 |
+
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 1143 |
+
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 1144 |
+
+
|
| 1145 |
+
+ text_ids = self._prepare_text_ids(prompt_embeds)
|
| 1146 |
+
+ text_ids = text_ids.to(device)
|
| 1147 |
+
+ return prompt_embeds, text_ids
|
| 1148 |
+
+
|
| 1149 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
| 1150 |
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 1151 |
+
+ if image.ndim != 4:
|
| 1152 |
+
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
| 1153 |
+
+
|
| 1154 |
+
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 1155 |
+
+ image_latents = self._patchify_latents(image_latents)
|
| 1156 |
+
+
|
| 1157 |
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
| 1158 |
+
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
| 1159 |
+
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
| 1160 |
+
+
|
| 1161 |
+
+ return image_latents
|
| 1162 |
+
+
|
| 1163 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
| 1164 |
+
+ def prepare_latents(
|
| 1165 |
+
+ self,
|
| 1166 |
+
+ batch_size,
|
| 1167 |
+
+ num_latents_channels,
|
| 1168 |
+
+ height,
|
| 1169 |
+
+ width,
|
| 1170 |
+
+ dtype,
|
| 1171 |
+
+ device,
|
| 1172 |
+
+ generator: torch.Generator,
|
| 1173 |
+
+ latents: torch.Tensor | None = None,
|
| 1174 |
+
+ ):
|
| 1175 |
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
|
| 1176 |
+
+ # latent height and width to be divisible by 2.
|
| 1177 |
+
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 1178 |
+
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 1179 |
+
+
|
| 1180 |
+
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
| 1181 |
+
+ if isinstance(generator, list) and len(generator) != batch_size:
|
| 1182 |
+
+ raise ValueError(
|
| 1183 |
+
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 1184 |
+
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 1185 |
+
+ )
|
| 1186 |
+
+ if latents is None:
|
| 1187 |
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 1188 |
+
+ else:
|
| 1189 |
+
+ latents = latents.to(device=device, dtype=dtype)
|
| 1190 |
+
+
|
| 1191 |
+
+ latent_ids = self._prepare_latent_ids(latents)
|
| 1192 |
+
+ latent_ids = latent_ids.to(device)
|
| 1193 |
+
+
|
| 1194 |
+
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
| 1195 |
+
+ return latents, latent_ids
|
| 1196 |
+
+
|
| 1197 |
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
| 1198 |
+
+ def prepare_image_latents(
|
| 1199 |
+
+ self,
|
| 1200 |
+
+ images: list[torch.Tensor],
|
| 1201 |
+
+ batch_size,
|
| 1202 |
+
+ generator: torch.Generator,
|
| 1203 |
+
+ device,
|
| 1204 |
+
+ dtype,
|
| 1205 |
+
+ ):
|
| 1206 |
+
+ image_latents = []
|
| 1207 |
+
+ for image in images:
|
| 1208 |
+
+ image = image.to(device=device, dtype=dtype)
|
| 1209 |
+
+ imagge_latent = self._encode_vae_image(image=image, generator=generator)
|
| 1210 |
+
+ image_latents.append(imagge_latent) # (1, 128, 32, 32)
|
| 1211 |
+
+
|
| 1212 |
+
+ image_latent_ids = self._prepare_image_ids(image_latents)
|
| 1213 |
+
+
|
| 1214 |
+
+ # Pack each latent and concatenate
|
| 1215 |
+
+ packed_latents = []
|
| 1216 |
+
+ for latent in image_latents:
|
| 1217 |
+
+ # latent: (1, 128, 32, 32)
|
| 1218 |
+
+ packed = self._pack_latents(latent) # (1, 1024, 128)
|
| 1219 |
+
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
| 1220 |
+
+ packed_latents.append(packed)
|
| 1221 |
+
+
|
| 1222 |
+
+ # Concatenate all reference tokens along sequence dimension
|
| 1223 |
+
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
| 1224 |
+
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
| 1225 |
+
+
|
| 1226 |
+
+ image_latents = image_latents.repeat(batch_size, 1, 1)
|
| 1227 |
+
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
| 1228 |
+
+ image_latent_ids = image_latent_ids.to(device)
|
| 1229 |
+
+
|
| 1230 |
+
+ return image_latents, image_latent_ids
|
| 1231 |
+
+
|
| 1232 |
+
+ def check_inputs(
|
| 1233 |
+
+ self,
|
| 1234 |
+
+ prompt,
|
| 1235 |
+
+ height,
|
| 1236 |
+
+ width,
|
| 1237 |
+
+ prompt_embeds=None,
|
| 1238 |
+
+ callback_on_step_end_tensor_inputs=None,
|
| 1239 |
+
+ ):
|
| 1240 |
+
+ if (
|
| 1241 |
+
+ height is not None
|
| 1242 |
+
+ and height % (self.vae_scale_factor * 2) != 0
|
| 1243 |
+
+ or width is not None
|
| 1244 |
+
+ and width % (self.vae_scale_factor * 2) != 0
|
| 1245 |
+
+ ):
|
| 1246 |
+
+ logger.warning(
|
| 1247 |
+
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 1248 |
+
+ )
|
| 1249 |
+
+
|
| 1250 |
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
|
| 1251 |
+
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 1252 |
+
+ ):
|
| 1253 |
+
+ raise ValueError(
|
| 1254 |
+
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 1255 |
+
+ )
|
| 1256 |
+
+
|
| 1257 |
+
+ if prompt is not None and prompt_embeds is not None:
|
| 1258 |
+
+ raise ValueError(
|
| 1259 |
+
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 1260 |
+
+ " only forward one of the two."
|
| 1261 |
+
+ )
|
| 1262 |
+
+ elif prompt is None and prompt_embeds is None:
|
| 1263 |
+
+ raise ValueError(
|
| 1264 |
+
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 1265 |
+
+ )
|
| 1266 |
+
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 1267 |
+
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 1268 |
+
+
|
| 1269 |
+
+ @property
|
| 1270 |
+
+ def attention_kwargs(self):
|
| 1271 |
+
+ return self._attention_kwargs
|
| 1272 |
+
+
|
| 1273 |
+
+ @property
|
| 1274 |
+
+ def num_timesteps(self):
|
| 1275 |
+
+ return self._num_timesteps
|
| 1276 |
+
+
|
| 1277 |
+
+ @property
|
| 1278 |
+
+ def current_timestep(self):
|
| 1279 |
+
+ return self._current_timestep
|
| 1280 |
+
+
|
| 1281 |
+
+ @property
|
| 1282 |
+
+ def interrupt(self):
|
| 1283 |
+
+ return self._interrupt
|
| 1284 |
+
+
|
| 1285 |
+
+ @torch.no_grad()
|
| 1286 |
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 1287 |
+
+ def __call__(
|
| 1288 |
+
+ self,
|
| 1289 |
+
+ image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
|
| 1290 |
+
+ prompt: str | list[str] = None,
|
| 1291 |
+
+ height: int | None = None,
|
| 1292 |
+
+ width: int | None = None,
|
| 1293 |
+
+ num_inference_steps: int = 4,
|
| 1294 |
+
+ sigmas: list[float] | None = None,
|
| 1295 |
+
+ num_images_per_prompt: int = 1,
|
| 1296 |
+
+ generator: torch.Generator | list[torch.Generator] | None = None,
|
| 1297 |
+
+ latents: torch.Tensor | None = None,
|
| 1298 |
+
+ prompt_embeds: torch.Tensor | None = None,
|
| 1299 |
+
+ output_type: str = "pil",
|
| 1300 |
+
+ return_dict: bool = True,
|
| 1301 |
+
+ attention_kwargs: dict[str, Any] | None = None,
|
| 1302 |
+
+ callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
| 1303 |
+
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
| 1304 |
+
+ max_sequence_length: int = 512,
|
| 1305 |
+
+ text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
| 1306 |
+
+ ):
|
| 1307 |
+
+ r"""
|
| 1308 |
+
+ Function invoked when calling the pipeline for generation.
|
| 1309 |
+
+
|
| 1310 |
+
+ Args:
|
| 1311 |
+
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
|
| 1312 |
+
+ Reference image(s) for conditioning. On the first denoising step, reference tokens are included
|
| 1313 |
+
+ in the forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are
|
| 1314 |
+
+ reused without recomputing.
|
| 1315 |
+
+ prompt (`str` or `List[str]`, *optional*):
|
| 1316 |
+
+ The prompt or prompts to guide the image generation.
|
| 1317 |
+
+ height (`int`, *optional*):
|
| 1318 |
+
+ The height in pixels of the generated image.
|
| 1319 |
+
+ width (`int`, *optional*):
|
| 1320 |
+
+ The width in pixels of the generated image.
|
| 1321 |
+
+ num_inference_steps (`int`, *optional*, defaults to 4):
|
| 1322 |
+
+ The number of denoising steps.
|
| 1323 |
+
+ sigmas (`List[float]`, *optional*):
|
| 1324 |
+
+ Custom sigmas for the denoising schedule.
|
| 1325 |
+
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1326 |
+
+ The number of images to generate per prompt.
|
| 1327 |
+
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 1328 |
+
+ Generator(s) for deterministic generation.
|
| 1329 |
+
+ latents (`torch.Tensor`, *optional*):
|
| 1330 |
+
+ Pre-generated noisy latents.
|
| 1331 |
+
+ prompt_embeds (`torch.Tensor`, *optional*):
|
| 1332 |
+
+ Pre-generated text embeddings.
|
| 1333 |
+
+ output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1334 |
+
+ Output format: `"pil"` or `"np"`.
|
| 1335 |
+
+ return_dict (`bool`, *optional*, defaults to `True`):
|
| 1336 |
+
+ Whether to return a `Flux2PipelineOutput` or a plain tuple.
|
| 1337 |
+
+ attention_kwargs (`dict`, *optional*):
|
| 1338 |
+
+ Extra kwargs passed to attention processors.
|
| 1339 |
+
+ callback_on_step_end (`Callable`, *optional*):
|
| 1340 |
+
+ Callback function called at the end of each denoising step.
|
| 1341 |
+
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1342 |
+
+ Tensor inputs for the callback function.
|
| 1343 |
+
+ max_sequence_length (`int`, defaults to 512):
|
| 1344 |
+
+ Maximum sequence length for the prompt.
|
| 1345 |
+
+ text_encoder_out_layers (`tuple[int]`):
|
| 1346 |
+
+ Layer indices for text encoder hidden state extraction.
|
| 1347 |
+
+
|
| 1348 |
+
+ Examples:
|
| 1349 |
+
+
|
| 1350 |
+
+ Returns:
|
| 1351 |
+
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
|
| 1352 |
+
+ """
|
| 1353 |
+
+
|
| 1354 |
+
+ # 1. Check inputs
|
| 1355 |
+
+ self.check_inputs(
|
| 1356 |
+
+ prompt=prompt,
|
| 1357 |
+
+ height=height,
|
| 1358 |
+
+ width=width,
|
| 1359 |
+
+ prompt_embeds=prompt_embeds,
|
| 1360 |
+
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1361 |
+
+ )
|
| 1362 |
+
+
|
| 1363 |
+
+ self._attention_kwargs = attention_kwargs
|
| 1364 |
+
+ self._current_timestep = None
|
| 1365 |
+
+ self._interrupt = False
|
| 1366 |
+
+
|
| 1367 |
+
+ # 2. Define call parameters
|
| 1368 |
+
+ if prompt is not None and isinstance(prompt, str):
|
| 1369 |
+
+ batch_size = 1
|
| 1370 |
+
+ elif prompt is not None and isinstance(prompt, list):
|
| 1371 |
+
+ batch_size = len(prompt)
|
| 1372 |
+
+ else:
|
| 1373 |
+
+ batch_size = prompt_embeds.shape[0]
|
| 1374 |
+
+
|
| 1375 |
+
+ device = self._execution_device
|
| 1376 |
+
+
|
| 1377 |
+
+ # 3. prepare text embeddings
|
| 1378 |
+
+ prompt_embeds, text_ids = self.encode_prompt(
|
| 1379 |
+
+ prompt=prompt,
|
| 1380 |
+
+ prompt_embeds=prompt_embeds,
|
| 1381 |
+
+ device=device,
|
| 1382 |
+
+ num_images_per_prompt=num_images_per_prompt,
|
| 1383 |
+
+ max_sequence_length=max_sequence_length,
|
| 1384 |
+
+ text_encoder_out_layers=text_encoder_out_layers,
|
| 1385 |
+
+ )
|
| 1386 |
+
+
|
| 1387 |
+
+ # 4. process images
|
| 1388 |
+
+ if image is not None and not isinstance(image, list):
|
| 1389 |
+
+ image = [image]
|
| 1390 |
+
+
|
| 1391 |
+
+ condition_images = None
|
| 1392 |
+
+ if image is not None:
|
| 1393 |
+
+ for img in image:
|
| 1394 |
+
+ self.image_processor.check_image_input(img)
|
| 1395 |
+
+
|
| 1396 |
+
+ condition_images = []
|
| 1397 |
+
+ for img in image:
|
| 1398 |
+
+ image_width, image_height = img.size
|
| 1399 |
+
+ if image_width * image_height > 1024 * 1024:
|
| 1400 |
+
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
| 1401 |
+
+ image_width, image_height = img.size
|
| 1402 |
+
+
|
| 1403 |
+
+ multiple_of = self.vae_scale_factor * 2
|
| 1404 |
+
+ image_width = (image_width // multiple_of) * multiple_of
|
| 1405 |
+
+ image_height = (image_height // multiple_of) * multiple_of
|
| 1406 |
+
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
| 1407 |
+
+ condition_images.append(img)
|
| 1408 |
+
+ height = height or image_height
|
| 1409 |
+
+ width = width or image_width
|
| 1410 |
+
+
|
| 1411 |
+
+ height = height or self.default_sample_size * self.vae_scale_factor
|
| 1412 |
+
+ width = width or self.default_sample_size * self.vae_scale_factor
|
| 1413 |
+
+
|
| 1414 |
+
+ # 5. prepare latent variables
|
| 1415 |
+
+ num_channels_latents = self.transformer.config.in_channels // 4
|
| 1416 |
+
+ latents, latent_ids = self.prepare_latents(
|
| 1417 |
+
+ batch_size=batch_size * num_images_per_prompt,
|
| 1418 |
+
+ num_latents_channels=num_channels_latents,
|
| 1419 |
+
+ height=height,
|
| 1420 |
+
+ width=width,
|
| 1421 |
+
+ dtype=prompt_embeds.dtype,
|
| 1422 |
+
+ device=device,
|
| 1423 |
+
+ generator=generator,
|
| 1424 |
+
+ latents=latents,
|
| 1425 |
+
+ )
|
| 1426 |
+
+
|
| 1427 |
+
+ image_latents = None
|
| 1428 |
+
+ image_latent_ids = None
|
| 1429 |
+
+ if condition_images is not None:
|
| 1430 |
+
+ image_latents, image_latent_ids = self.prepare_image_latents(
|
| 1431 |
+
+ images=condition_images,
|
| 1432 |
+
+ batch_size=batch_size * num_images_per_prompt,
|
| 1433 |
+
+ generator=generator,
|
| 1434 |
+
+ device=device,
|
| 1435 |
+
+ dtype=self.vae.dtype,
|
| 1436 |
+
+ )
|
| 1437 |
+
+
|
| 1438 |
+
+ # 6. Prepare timesteps
|
| 1439 |
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 1440 |
+
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
| 1441 |
+
+ sigmas = None
|
| 1442 |
+
+ image_seq_len = latents.shape[1]
|
| 1443 |
+
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
| 1444 |
+
+ timesteps, num_inference_steps = retrieve_timesteps(
|
| 1445 |
+
+ self.scheduler,
|
| 1446 |
+
+ num_inference_steps,
|
| 1447 |
+
+ device,
|
| 1448 |
+
+ sigmas=sigmas,
|
| 1449 |
+
+ mu=mu,
|
| 1450 |
+
+ )
|
| 1451 |
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1452 |
+
+ self._num_timesteps = len(timesteps)
|
| 1453 |
+
+
|
| 1454 |
+
+ # 7. Denoising loop with KV caching
|
| 1455 |
+
+ # Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
|
| 1456 |
+
+ # Steps 1+: forward_kv_cached (reuse cached ref K/V)
|
| 1457 |
+
+ # No ref images: standard forward
|
| 1458 |
+
+ self.scheduler.set_begin_index(0)
|
| 1459 |
+
+ kv_cache = None
|
| 1460 |
+
+
|
| 1461 |
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1462 |
+
+ for i, t in enumerate(timesteps):
|
| 1463 |
+
+ if self.interrupt:
|
| 1464 |
+
+ continue
|
| 1465 |
+
+
|
| 1466 |
+
+ self._current_timestep = t
|
| 1467 |
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 1468 |
+
+
|
| 1469 |
+
+ if i == 0 and image_latents is not None:
|
| 1470 |
+
+ # Step 0: include ref tokens, extract KV cache
|
| 1471 |
+
+ latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
|
| 1472 |
+
+ latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
|
| 1473 |
+
+
|
| 1474 |
+
+ output, kv_cache = self.transformer(
|
| 1475 |
+
+ hidden_states=latent_model_input,
|
| 1476 |
+
+ timestep=timestep / 1000,
|
| 1477 |
+
+ guidance=None,
|
| 1478 |
+
+ encoder_hidden_states=prompt_embeds,
|
| 1479 |
+
+ txt_ids=text_ids,
|
| 1480 |
+
+ img_ids=latent_image_ids,
|
| 1481 |
+
+ joint_attention_kwargs=self.attention_kwargs,
|
| 1482 |
+
+ return_dict=False,
|
| 1483 |
+
+ kv_cache_mode="extract",
|
| 1484 |
+
+ num_ref_tokens=image_latents.shape[1],
|
| 1485 |
+
+ )
|
| 1486 |
+
+ noise_pred = output[0]
|
| 1487 |
+
+
|
| 1488 |
+
+ elif kv_cache is not None:
|
| 1489 |
+
+ # Steps 1+: use cached ref KV, no ref tokens in input
|
| 1490 |
+
+ noise_pred = self.transformer(
|
| 1491 |
+
+ hidden_states=latents.to(self.transformer.dtype),
|
| 1492 |
+
+ timestep=timestep / 1000,
|
| 1493 |
+
+ guidance=None,
|
| 1494 |
+
+ encoder_hidden_states=prompt_embeds,
|
| 1495 |
+
+ txt_ids=text_ids,
|
| 1496 |
+
+ img_ids=latent_ids,
|
| 1497 |
+
+ joint_attention_kwargs=self.attention_kwargs,
|
| 1498 |
+
+ return_dict=False,
|
| 1499 |
+
+ kv_cache=kv_cache,
|
| 1500 |
+
+ kv_cache_mode="cached",
|
| 1501 |
+
+ )[0]
|
| 1502 |
+
+
|
| 1503 |
+
+ else:
|
| 1504 |
+
+ # No reference images: standard forward
|
| 1505 |
+
+ noise_pred = self.transformer(
|
| 1506 |
+
+ hidden_states=latents.to(self.transformer.dtype),
|
| 1507 |
+
+ timestep=timestep / 1000,
|
| 1508 |
+
+ guidance=None,
|
| 1509 |
+
+ encoder_hidden_states=prompt_embeds,
|
| 1510 |
+
+ txt_ids=text_ids,
|
| 1511 |
+
+ img_ids=latent_ids,
|
| 1512 |
+
+ joint_attention_kwargs=self.attention_kwargs,
|
| 1513 |
+
+ return_dict=False,
|
| 1514 |
+
+ )[0]
|
| 1515 |
+
+
|
| 1516 |
+
+ # compute the previous noisy sample x_t -> x_t-1
|
| 1517 |
+
+ latents_dtype = latents.dtype
|
| 1518 |
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1519 |
+
+
|
| 1520 |
+
+ if latents.dtype != latents_dtype:
|
| 1521 |
+
+ if torch.backends.mps.is_available():
|
| 1522 |
+
+ latents = latents.to(latents_dtype)
|
| 1523 |
+
+
|
| 1524 |
+
+ if callback_on_step_end is not None:
|
| 1525 |
+
+ callback_kwargs = {}
|
| 1526 |
+
+ for k in callback_on_step_end_tensor_inputs:
|
| 1527 |
+
+ callback_kwargs[k] = locals()[k]
|
| 1528 |
+
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1529 |
+
+
|
| 1530 |
+
+ latents = callback_outputs.pop("latents", latents)
|
| 1531 |
+
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1532 |
+
+
|
| 1533 |
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1534 |
+
+ progress_bar.update()
|
| 1535 |
+
+
|
| 1536 |
+
+ if XLA_AVAILABLE:
|
| 1537 |
+
+ xm.mark_step()
|
| 1538 |
+
+
|
| 1539 |
+
+ # Clean up KV cache
|
| 1540 |
+
+ if kv_cache is not None:
|
| 1541 |
+
+ kv_cache.clear()
|
| 1542 |
+
+
|
| 1543 |
+
+ self._current_timestep = None
|
| 1544 |
+
+
|
| 1545 |
+
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
|
| 1546 |
+
+
|
| 1547 |
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
| 1548 |
+
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
| 1549 |
+
+ latents.device, latents.dtype
|
| 1550 |
+
+ )
|
| 1551 |
+
+ latents = latents * latents_bn_std + latents_bn_mean
|
| 1552 |
+
+ latents = self._unpatchify_latents(latents)
|
| 1553 |
+
+ if output_type == "latent":
|
| 1554 |
+
+ image = latents
|
| 1555 |
+
+ else:
|
| 1556 |
+
+ image = self.vae.decode(latents, return_dict=False)[0]
|
| 1557 |
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1558 |
+
+
|
| 1559 |
+
+ # Offload all models
|
| 1560 |
+
+ self.maybe_free_model_hooks()
|
| 1561 |
+
+
|
| 1562 |
+
+ if not return_dict:
|
| 1563 |
+
+ return (image,)
|
| 1564 |
+
+
|
| 1565 |
+
+ return Flux2PipelineOutput(images=image)
|
reference.jpg
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.13.0
|
| 2 |
+
diffusers==0.37.0
|
| 3 |
+
huggingface_hub==1.5.0
|
| 4 |
+
sentencepiece==0.2.1
|
| 5 |
+
protobuf==3.20.3
|
| 6 |
+
transformers==5.3.0
|
| 7 |
+
torch==2.9.1
|
| 8 |
+
Pillow
|
styles.json
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"id": "anime",
|
| 4 |
+
"name": "Anime",
|
| 5 |
+
"emoji": "🎌",
|
| 6 |
+
"description": "Clean high-fidelity anime character portrait",
|
| 7 |
+
"prompt": "Transform the uploaded person into a polished anime character portrait with clean cel-shaded line art, large expressive eyes, refined face simplification, crisp eyelash detailing, smooth gradient skin tones, vibrant but controlled color design, elegant hair separation, hand-drawn animation appeal, premium key visual composition, modern theatrical anime finish, and the luminous character-design quality of a top-tier Japanese animation studio."
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"id": "ps2_y2k_gaming",
|
| 11 |
+
"name": "PS2 / Y2K Gaming",
|
| 12 |
+
"emoji": "🎮",
|
| 13 |
+
"description": "Sixth-generation console render with glossy Y2K game aesthetics",
|
| 14 |
+
"prompt": "Transform the uploaded person into a PlayStation 2 era character render with early-2000s Y2K gaming aesthetics, low-to-mid polygon facial construction, crisp geometric forms, glossy skin shader, chunky specular highlights, baked ambient occlusion, subtle texture compression, dramatic menu-screen lighting, saturated console-era color grading, metallic futuristic fashion cues, stylish 2000s promotional energy, and the polished character-select screen look of a premium sixth-generation action game.",
|
| 15 |
+
"preview_seed": 101
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"id": "claymation_stop_motion",
|
| 19 |
+
"name": "Claymation / Stop Motion",
|
| 20 |
+
"emoji": "🧱",
|
| 21 |
+
"description": "Handmade stop-motion puppet built from sculpted clay",
|
| 22 |
+
"prompt": "Transform the uploaded person into a handcrafted claymation stop-motion puppet made from sculpted plasticine, visible fingerprints in the clay, slightly uneven handmade facial symmetry, pressed clothing details, tactile matte surfaces, miniature set lighting, practical stop-motion puppet proportions, tiny prop realism, shallow miniature depth of field, charming handcrafted imperfections, and the cozy premium look of a photographed studio stop-motion character.",
|
| 23 |
+
"preview_seed": 102
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"id": "pop_mart_blind_box",
|
| 27 |
+
"name": "Pop Mart / Blind Box Toy",
|
| 28 |
+
"emoji": "🧸",
|
| 29 |
+
"description": "Cute designer toy collectible with glossy vinyl finish",
|
| 30 |
+
"prompt": "Transform the uploaded person into a premium blind-box designer toy collectible, oversized head-to-body proportions, rounded limbs, simplified but recognizable facial features, glossy vinyl material, soft pastel luxury palette, tiny mouth, blush gradients, polished paint application, immaculate factory finish, cute shelf-display appeal, and the upscale commercial product-shot aesthetic of a high-end collectible art toy.",
|
| 31 |
+
"preview_seed": 103
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": "gta_loading_screen",
|
| 35 |
+
"name": "GTA Loading Screen",
|
| 36 |
+
"emoji": "🚘",
|
| 37 |
+
"description": "Bold illustrated splash art inspired by classic GTA loading screens",
|
| 38 |
+
"prompt": "Transform the uploaded person into a bold Grand Theft Auto style loading-screen illustration, confident character attitude, sharp graphic outlines, semi-realistic comic-book painting, smooth digital airbrush shading, punchy warm highlights, high-contrast skin rendering, glamorous urban styling, poster-like color blocking, polished splash-art finish, and the unmistakable blockbuster open-world crime-game loading screen aesthetic.",
|
| 39 |
+
"preview_seed": 104
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"id": "retro_90s_yearbook",
|
| 43 |
+
"name": "90s Yearbook / Retro",
|
| 44 |
+
"emoji": "📚",
|
| 45 |
+
"description": "Nostalgic 1990s school portrait with analog studio charm",
|
| 46 |
+
"prompt": "Transform the uploaded person into a 1990s yearbook portrait with mall-studio photography aesthetics, soft frontal flash, mottled blue or purple backdrop, gentle diffusion, natural analog skin tones, subtle film grain, slightly faded print quality, awkwardly sincere school-photo energy, retro grooming cues, thin halation, and the authentic printed look of a real 90s yearbook headshot.",
|
| 47 |
+
"preview_seed": 105
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"id": "ghibli_90s_anime",
|
| 51 |
+
"name": "Ghibli / 90s Anime",
|
| 52 |
+
"emoji": "🌿",
|
| 53 |
+
"description": "Warm hand-painted 1990s fantasy anime portrait",
|
| 54 |
+
"prompt": "Transform the uploaded person into a warm 1990s fantasy anime portrait with gentle cel animation linework, soft hand-painted color styling, expressive but calm eyes, youthful face simplification, watercolor atmosphere, delicate blush tones, magical natural light, pastoral fantasy warmth, cinematic emotional storytelling, and the nostalgic painted-film quality of a classic 1990s anime feature.",
|
| 55 |
+
"preview_seed": 106
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"id": "inflatable_balloon_art",
|
| 59 |
+
"name": "Inflatable / 3D Balloon Art",
|
| 60 |
+
"emoji": "🎈",
|
| 61 |
+
"description": "Glossy inflatable sculpture with oversized rounded forms",
|
| 62 |
+
"prompt": "Transform the uploaded person into a glossy inflatable balloon-art sculpture, oversized rounded anatomy, sealed vinyl seams, ultra-smooth reflective material, playful exaggerated volume, candy-like color design, high-specular reflections, inflated clothing shapes, whimsical premium 3D render quality, and the luxurious visual language of high-end contemporary balloon sculpture art.",
|
| 63 |
+
"preview_seed": 107
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"id": "pixel_art_16bit",
|
| 67 |
+
"name": "16-Bit Pixel Art",
|
| 68 |
+
"emoji": "👾",
|
| 69 |
+
"description": "Detailed 16-bit console sprite portrait",
|
| 70 |
+
"prompt": "Transform the uploaded person into detailed 16-bit pixel art, carefully reduced facial features for sprite readability, limited retro palette, crisp pixel clusters, deliberate dithering, SNES-era RPG portrait design, bold silhouette readability, expressive eyes built from minimal pixel forms, clean sprite shading, cartridge-era charm, and the polished character-portrait look of a premium 1990s console role-playing game.",
|
| 71 |
+
"preview_seed": 108
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"id": "clay_sculpture",
|
| 75 |
+
"name": "Clay",
|
| 76 |
+
"emoji": "🪵",
|
| 77 |
+
"description": "Handcrafted clay or ceramic bust with tactile sculpted detail",
|
| 78 |
+
"prompt": "Transform the uploaded person into a handcrafted clay sculpture portrait, tactile earthen material, sculpted cheekbones and hair ridges, ceramic or raw-clay texture variation, hand-tool marks, subtle cracks, carved garment detail, grounded natural palette, gallery-style side lighting, and the substantial physical presence of a carefully sculpted clay bust.",
|
| 79 |
+
"preview_seed": 109
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"id": "toy_collectible",
|
| 83 |
+
"name": "Toy",
|
| 84 |
+
"emoji": "🧸",
|
| 85 |
+
"description": "Stylized collectible toy figure with premium retail finish",
|
| 86 |
+
"prompt": "Transform the uploaded person into a stylized collectible toy figure, premium molded plastic materials, clean manufactured edges, slightly enlarged head, simplified but recognizable likeness, painted toy detailing, glossy eyes, shelf-ready silhouette, commercial product lighting, and the finish quality of an officially released limited-edition character toy.",
|
| 87 |
+
"preview_seed": 110
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"id": "sims_style",
|
| 91 |
+
"name": "The Sims Style",
|
| 92 |
+
"emoji": "💚",
|
| 93 |
+
"description": "Friendly life-simulation avatar with polished game-render styling",
|
| 94 |
+
"prompt": "Transform the uploaded person into a life-simulation game avatar with polished The Sims inspired styling, bright approachable face design, clean character-creator grooming, smooth game-render skin, slightly idealized proportions, cheerful suburban presentation, crisp digital lighting, customizable-avatar appeal, and the unmistakable promotional look of a premium life-sim character portrait.",
|
| 95 |
+
"preview_seed": 111
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"id": "gta_character_art",
|
| 99 |
+
"name": "GTA Style",
|
| 100 |
+
"emoji": "🔫",
|
| 101 |
+
"description": "High-impact urban crime-game promotional character art",
|
| 102 |
+
"prompt": "Transform the uploaded person into stylized open-world crime-game key art with GTA-inspired visual language, sharp facial planes, smooth posterized shading, assertive street attitude, bold edge highlights, vivid urban contrast, semi-realistic comic painting, polished marketing illustration finish, and the high-impact character-poster energy of a blockbuster crime saga game.",
|
| 103 |
+
"preview_seed": 112
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"id": "minecraft_style",
|
| 107 |
+
"name": "Minecraft Style",
|
| 108 |
+
"emoji": "🟩",
|
| 109 |
+
"description": "Block-built voxel avatar with iconic sandbox-game geometry",
|
| 110 |
+
"prompt": "Transform the uploaded person into a Minecraft-inspired voxel character, square head, cubic limbs, block-built anatomy, pixel-textured clothing, low-resolution texture mapping, readable blocky hairstyle translation, simplified face, bright sandbox-world lighting, chunky geometric construction, and the instantly recognizable look of a polished block-game avatar.",
|
| 111 |
+
"preview_seed": 113
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"id": "roblox_style",
|
| 115 |
+
"name": "Roblox Style",
|
| 116 |
+
"emoji": "🟥",
|
| 117 |
+
"description": "Playful modular avatar with clean platform-game styling",
|
| 118 |
+
"prompt": "Transform the uploaded person into a Roblox-inspired avatar portrait, modular toy-like body construction, simplified geometry, glossy materials, playful proportions, readable face design, bright platform-game lighting, fashionable avatar styling, clean digital render finish, and the vibrant social-gaming look of a premium customizable online avatar.",
|
| 119 |
+
"preview_seed": 114
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"id": "vhs_style",
|
| 123 |
+
"name": "VHS Style",
|
| 124 |
+
"emoji": "📼",
|
| 125 |
+
"description": "Degraded analog tape still with authentic VHS artifacts",
|
| 126 |
+
"prompt": "Transform the uploaded person into a VHS-era analog video still, interlaced scanlines, tape noise, chromatic bleed, tracking distortion, magnetic dropouts, dated camcorder color response, low-resolution softness, blown highlights, shadow crush, slight timestamp-era nostalgia, and the authentic degraded feeling of a paused home-video frame transferred from tape.",
|
| 127 |
+
"preview_seed": 115
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"id": "polaroid_style",
|
| 131 |
+
"name": "Polaroid Style",
|
| 132 |
+
"emoji": "🖼️",
|
| 133 |
+
"description": "Instant-film portrait with soft flash and chemical color shifts",
|
| 134 |
+
"prompt": "Transform the uploaded person into a classic Polaroid instant-photo portrait, soft flash falloff, creamy analog color shifts, lightly washed highlights, vintage chemical film response, subtle edge fading, gentle grain, tactile print nostalgia, candid warmth, imperfect instant-camera sharpness, and the authentic charm of a real instant photograph.",
|
| 135 |
+
"preview_seed": 116
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"id": "webtoon_style",
|
| 139 |
+
"name": "Webtoon Style",
|
| 140 |
+
"emoji": "📱",
|
| 141 |
+
"description": "Glossy serialized-comic portrait with polished webtoon rendering",
|
| 142 |
+
"prompt": "Transform the uploaded person into a polished webtoon character portrait, elegant clean line art, luminous skin rendering, expressive stylized eyes, soft gradient shading, romantic dramatic atmosphere, fashion-forward character styling, glossy digital coloring, refined facial simplification, and the premium visual quality of a top-tier Korean webtoon cover panel.",
|
| 143 |
+
"preview_seed": 117
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"id": "pixar_style",
|
| 147 |
+
"name": "Pixar Style",
|
| 148 |
+
"emoji": "✨",
|
| 149 |
+
"description": "Cinematic stylized 3D animated character render",
|
| 150 |
+
"prompt": "Transform the uploaded person into a stylized cinematic 3D animated character with Pixar-inspired appeal, expressive eyes, carefully designed facial proportions, believable subsurface scattering, smooth skin shading, clean hair groom, emotionally readable expression, art-directed color harmony, cinematic key-and-fill lighting, and the premium theatrical quality of a feature-animation hero render.",
|
| 151 |
+
"preview_seed": 118
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"id": "wizard_school_fantasy",
|
| 155 |
+
"name": "Wizard School Identity",
|
| 156 |
+
"emoji": "🪄",
|
| 157 |
+
"description": "Magical academy identity portrait in a spellbound school world",
|
| 158 |
+
"prompt": "Transform the uploaded person into a magical boarding-school fantasy character portrait, scholarly robe tailoring, enchanted academic atmosphere, candlelit old-world interiors, brass and parchment details, subtle magical particles, house-color accents, moody castle lighting, elegant fantasy realism, and the immersive identity-shift feeling of belonging inside a grand wizard school universe.",
|
| 159 |
+
"preview_seed": 119
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"id": "dune_desert_epic",
|
| 163 |
+
"name": "Dune Identity",
|
| 164 |
+
"emoji": "🏜️",
|
| 165 |
+
"description": "Mythic desert sci-fi identity with austere cinematic grandeur",
|
| 166 |
+
"prompt": "Transform the uploaded person into a desert-epic science-fiction character, windswept arid atmosphere, monumental cinematic seriousness, sand-toned palette, survivalist desert garments, futuristic nomadic detailing, austere regal styling, harsh sunlit contrast, drifting dust, mythic scale, and the solemn high-fashion science-fantasy presence of a legendary desert-world figure.",
|
| 167 |
+
"preview_seed": 120
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"id": "viking_saga",
|
| 171 |
+
"name": "Viking Identity",
|
| 172 |
+
"emoji": "🪓",
|
| 173 |
+
"description": "Rugged Norse saga portrait with warrior realism",
|
| 174 |
+
"prompt": "Transform the uploaded person into a Viking saga character portrait, weathered Nordic styling, braided hair or fur-lined garments, carved metal detail, worn leather and wool textures, cold-climate realism, stormy dramatic light, warrior dignity, historical-fantasy grit, and the fierce mythic aura of a cinematic Norse hero.",
|
| 175 |
+
"preview_seed": 121
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"id": "superhero_identity",
|
| 179 |
+
"name": "Super Hero Identity",
|
| 180 |
+
"emoji": "🦸",
|
| 181 |
+
"description": "Blockbuster hero reveal with stylized cinematic power",
|
| 182 |
+
"prompt": "Transform the uploaded person into a hyper-stylized superhero character portrait, iconic costume-design cues, bold chest-and-shoulder silhouette, cinematic heroic lighting, dramatic backlight, high-energy color blocking, polished comic-to-film realism, powerful confident expression, and the larger-than-life franchise-poster aura of a modern blockbuster hero reveal.",
|
| 183 |
+
"preview_seed": 122
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"id": "paper_cutout",
|
| 187 |
+
"name": "Paper Cutout",
|
| 188 |
+
"emoji": "✂️",
|
| 189 |
+
"description": "Layered paper collage with tactile handcrafted depth",
|
| 190 |
+
"prompt": "Transform the uploaded person into a layered paper-cut illustration, hand-cut cardstock shapes, stacked paper planes, crisp scissor-cut edges, soft cast shadows between layers, simplified graphic facial features, subtle paper fiber texture, handcrafted composition, and the elegant dimensional charm of premium artisanal papercraft artwork.",
|
| 191 |
+
"preview_seed": 123
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"id": "baroque",
|
| 195 |
+
"name": "Baroque",
|
| 196 |
+
"emoji": "👑",
|
| 197 |
+
"description": "Ornate old-master portrait with dramatic Baroque lighting",
|
| 198 |
+
"prompt": "Transform the uploaded person into a grand Baroque portrait painting, rich velvet and silk costume textures, ornate gold embellishment, museum-grade oil-paint surface, dramatic chiaroscuro, deep tenebrist shadows, warm candlelit highlights, aristocratic posture, old-master brushwork, and the opulent 17th-century grandeur of a prestigious European court portrait.",
|
| 199 |
+
"preview_seed": 124
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"id": "dark_fantasy",
|
| 203 |
+
"name": "Dark Fantasy",
|
| 204 |
+
"emoji": "🕯️",
|
| 205 |
+
"description": "Brooding gothic fantasy portrait with ominous atmosphere",
|
| 206 |
+
"prompt": "Transform the uploaded person into a dark fantasy character portrait, gothic worldbuilding atmosphere, ominous candlelit shadows, ancient armor or ritual garments, ash-and-blood color accents, weathered surfaces, haunted elegance, moody volumetric haze, cursed-kingdom tone, and the brooding visual depth of a prestige dark-fantasy epic.",
|
| 207 |
+
"preview_seed": 125
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"id": "cyberpunk_neon_noir",
|
| 211 |
+
"name": "Cyberpunk / Neon Noir",
|
| 212 |
+
"emoji": "🌆",
|
| 213 |
+
"description": "Rain-soaked futuristic noir lit by neon signage",
|
| 214 |
+
"prompt": "Transform the uploaded person into a neon-noir cyberpunk portrait, rain-slick reflections, magenta and cyan practical lighting, holographic signage glow, moody urban darkness, reflective synthetic fabrics, futuristic street-culture styling, soft atmospheric mist, glossy skin highlights from neon, and the immersive nighttime mood of a premium neon-soaked sci-fi thriller.",
|
| 215 |
+
"preview_seed": 126
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"id": "ukiyoe_style",
|
| 219 |
+
"name": "Ukiyo-e",
|
| 220 |
+
"emoji": "🌊",
|
| 221 |
+
"description": "Elegant Japanese floating-world portrait print",
|
| 222 |
+
"prompt": "Transform the uploaded person into a refined ukiyo-e portrait, graceful flattened perspective, controlled ink outlines, flowing garment silhouettes, muted mineral pigments, washi-paper feel, decorative hair treatment, balanced negative space, Edo-period print sensibility, and the serene crafted beauty of a classic Japanese floating-world print.",
|
| 223 |
+
"preview_seed": 127
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"id": "woodblock_print",
|
| 227 |
+
"name": "Woodblock",
|
| 228 |
+
"emoji": "🪵",
|
| 229 |
+
"description": "Hand-pulled relief print with carved texture and ink variation",
|
| 230 |
+
"prompt": "Transform the uploaded person into a traditional woodblock print portrait, carved-line texture, visible ink transfer irregularities, reduced tonal shapes, tactile handmade printmaking character, rough-paper grain, artisan block-carved detail, limited print palette, and the physical analog feeling of a carefully hand-pulled relief print.",
|
| 231 |
+
"preview_seed": 128
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"id": "action_figure",
|
| 235 |
+
"name": "Action Figure",
|
| 236 |
+
"emoji": "🦾",
|
| 237 |
+
"description": "Collector-grade articulated figure based on the subject",
|
| 238 |
+
"prompt": "Transform the uploaded person into a premium articulated action figure, realistic molded head sculpt, visible joint engineering, semi-matte plastic body, painted costume detail, collector-grade product lighting, hero-franchise presentation, crisp accessory-ready design language, and the convincing look of a high-end retail action figure based on the subject.",
|
| 239 |
+
"preview_seed": 129
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"id": "doll_packaging",
|
| 243 |
+
"name": "Doll Packaging",
|
| 244 |
+
"emoji": "🎁",
|
| 245 |
+
"description": "Fashion doll presented inside premium retail box packaging",
|
| 246 |
+
"prompt": "Transform the uploaded person into a fashion doll presentation inside premium retail packaging, stylized doll proportions, glossy plastic face paint, immaculate rooted-hair appearance, product-window box framing, coordinated accessories, bright commercial toy photography lighting, shelf-display glamour, and the fully merchandised elegance of a high-end doll-box release.",
|
| 247 |
+
"preview_seed": 130
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"id": "crochet_style",
|
| 251 |
+
"name": "Crochet",
|
| 252 |
+
"emoji": "🧶",
|
| 253 |
+
"description": "Hand-crocheted doll portrait with visible looped stitching",
|
| 254 |
+
"prompt": "Transform the uploaded person into a handcrafted crochet figure portrait, visible looped stitches, yarn-built facial features, soft stuffed volume, carefully stitched clothing detail, tactile fiber texture, warm craft-table atmosphere, subtle handmade irregularity, and the lovable artisanal quality of an expertly crocheted doll.",
|
| 255 |
+
"preview_seed": 131
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"id": "yarn_plush",
|
| 259 |
+
"name": "Yarn",
|
| 260 |
+
"emoji": "🪡",
|
| 261 |
+
"description": "Soft plush character built from thick colorful yarn",
|
| 262 |
+
"prompt": "Transform the uploaded person into a plush yarn-crafted character, fluffy wool fibers, soft textile construction, rounded stuffed-toy anatomy, fuzzy tactile surface, adorable stitched detailing, cozy handmade softness, gentle nursery-style lighting, and the charming volume of a premium yarn plush character.",
|
| 263 |
+
"preview_seed": 132
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"id": "creepy_doll",
|
| 267 |
+
"name": "Creepy Doll",
|
| 268 |
+
"emoji": "🪆",
|
| 269 |
+
"description": "Uncanny haunted-doll portrait with aged porcelain detail",
|
| 270 |
+
"prompt": "Transform the uploaded person into an eerie creepy-doll portrait, glassy unsettling eyes, cracked porcelain texture, antique doll construction, faintly uncanny smile, distressed costume details, dim haunted-house lighting, vintage-toy horror atmosphere, subtle age and wear marks, and the disturbing but carefully crafted presence of a haunted collectible doll.",
|
| 271 |
+
"preview_seed": 133
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"id": "ultra_ugly_caricature",
|
| 275 |
+
"name": "Ultra Ugly Caricature",
|
| 276 |
+
"emoji": "🤪",
|
| 277 |
+
"description": "Extreme grotesque caricature pushed for maximum comedic distortion",
|
| 278 |
+
"prompt": "Transform the uploaded person into an ultra-exaggerated ugly caricature, intentionally awkward proportions, amplified facial asymmetry, oversized nose and jaw cues when naturally present, absurd comedic distortion, lumpy stylization, grotesque editorial-cartoon boldness, ugly-but-recognizable likeness, and the maximum-pushed comedic exaggeration of a ruthless caricature artist rendering.",
|
| 279 |
+
"preview_seed": 134
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"id": "classic_caricature",
|
| 283 |
+
"name": "Caricature",
|
| 284 |
+
"emoji": "🎭",
|
| 285 |
+
"description": "Playful editorial caricature that exaggerates signature features",
|
| 286 |
+
"prompt": "Transform the uploaded person into a polished caricature portrait, amplified signature facial traits, playful proportions, expressive linework, satirical illustration charm, painterly editorial finish, witty character design, lively exaggerated smile and eye shapes, and the appealing handcrafted look of a professional caricature artist creating a recognizable stylized likeness.",
|
| 287 |
+
"preview_seed": 135
|
| 288 |
+
}
|
| 289 |
+
]
|