artificialguybr committed on
Commit
55acca5
·
verified ·
1 Parent(s): 1ba7952

Set up Style My Portrait branding and app

Browse files
Files changed (8) hide show
  1. .gitattributes +1 -0
  2. README.md +2 -2
  3. SpaceGrotesk-Bold.ttf +0 -0
  4. app.py +271 -0
  5. flux2_klein_kv.patch +1565 -0
  6. reference.jpg +3 -0
  7. requirements.txt +8 -0
  8. styles.json +289 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ reference.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: Style My Portrait
3
- emoji: 📈
4
- colorFrom: pink
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.13.0
 
1
  ---
2
  title: Style My Portrait
3
+ emoji: 🖼️
4
+ colorFrom: yellow
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.13.0
SpaceGrotesk-Bold.ttf ADDED
Binary file (86.5 kB). View file
 
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import subprocess
5
+ import spaces
6
+
7
+ HF_TOKEN = os.environ.get("HF_TOKEN")
8
+
9
def apply_patch():
    """Apply the bundled ``flux2_klein_kv.patch`` to the installed diffusers package.

    The patch is applied with ``patch -p2 --forward --batch`` from the
    directory that contains the ``diffusers`` package. If the patch file is
    not present next to this script, nothing is done.

    A non-zero exit status from ``patch`` is reported on stdout rather than
    raised: with ``--forward`` an already-applied patch is skipped, and
    ``patch`` may report that with a non-zero status.
    """
    import diffusers

    package_dir = os.path.dirname(diffusers.__file__)
    patch_file = os.path.join(os.path.dirname(__file__), "flux2_klein_kv.patch")
    if not os.path.exists(patch_file):
        return

    # Use a context manager so the patch file handle is closed once the
    # subprocess has finished reading it.
    with open(patch_file) as patch_stream:
        result = subprocess.run(
            ["patch", "-p2", "--forward", "--batch"],
            cwd=os.path.dirname(package_dir),
            stdin=patch_stream,
            capture_output=True,
            text=True,
        )

    if result.returncode != 0:
        # Surface the tool output so a rejected hunk is not silent.
        print(f"patch exited with status {result.returncode}:\n{result.stdout}{result.stderr}")
21
+
22
+ apply_patch()
23
+
24
+ import numpy as np
25
+ import torch
26
+ from PIL import Image
27
+ from diffusers.pipelines.flux2.pipeline_flux2_klein_kv import Flux2KleinKVPipeline
28
+ import gradio as gr
29
+
30
+ APP_DIR = os.path.dirname(os.path.abspath(__file__))
31
+ STYLES_PATH = os.path.join(APP_DIR, "styles.json")
32
+ REFERENCE_PATH = os.path.join(APP_DIR, "reference.jpg")
33
+ MODEL_ID = "black-forest-labs/FLUX.2-klein-9b-kv"
34
+ MANDATORY_PROMPT = (
35
+ "You must preserve the exact identity, facial features, body proportions, and pose of the subject. "
36
+ "Keep the original camera framing, object positions, and full scene layout unchanged. "
37
+ "Apply only the requested visual style while maintaining total consistency and realism."
38
+ )
39
+
40
+ dtype = torch.bfloat16
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ MAX_SEED = np.iinfo(np.int32).max
43
+
44
+ with open(STYLES_PATH, "r", encoding="utf-8") as f:
45
+ STYLES = json.load(f)
46
+
47
+ pipe = Flux2KleinKVPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype, token=HF_TOKEN)
48
+ pipe.to("cuda")
49
+
50
+ PREVIEW_CACHE = {}
51
+
52
def prepare_image(image, max_size=1024):
    """Resize ``image`` so its longer side equals ``max_size``.

    The shorter side is derived from the aspect ratio and rounded to a
    multiple of 16; both sides are then clamped to ``[256, max_size]``.

    Args:
        image: Object exposing ``.size`` as ``(width, height)`` and
            ``.resize`` (a PIL image in this app).
        max_size: Target length of the longer side.

    Returns:
        Tuple ``(resized_image, width, height)``.
    """

    def _snap_to_16(length):
        # Round a length to the nearest multiple of 16.
        return round(length / 16) * 16

    def _clamp(length):
        # Keep a side length within [256, max_size].
        return max(256, min(max_size, length))

    src_w, src_h = image.size
    aspect = src_w / src_h

    if aspect >= 1:
        # Landscape or square: width is the long side.
        target_w, target_h = max_size, _snap_to_16(max_size / aspect)
    else:
        # Portrait: height is the long side.
        target_w, target_h = _snap_to_16(max_size * aspect), max_size

    target_w = _clamp(target_w)
    target_h = _clamp(target_h)
    resized = image.resize((target_w, target_h), Image.LANCZOS)
    return resized, target_w, target_h
63
+
64
def build_prompt(style_prompt):
    """Return the full prompt: ``MANDATORY_PROMPT``, a space, then ``style_prompt``."""
    return "{} {}".format(MANDATORY_PROMPT, style_prompt)
66
+
67
@spaces.GPU(duration=85)
def transform_image(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
    """Restyle an uploaded image with the style identified by ``style_id``.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        style_id: ``"id"`` of an entry in ``STYLES``.
        seed: Seed for the generator; ignored when ``randomize_seed`` is set.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        num_steps: Value passed to the pipeline as ``num_inference_steps``.
        progress: Gradio progress tracker.

    Returns:
        ``((resized_original, result), seed, resized_original, result)``.

    Raises:
        gr.Error: If ``image`` is ``None`` or no style matches ``style_id``.
    """
    if image is None:
        raise gr.Error("Upload a photo first!")

    # Take the first style whose id matches.
    style = None
    for candidate in STYLES:
        if candidate["id"] == style_id:
            style = candidate
            break
    if not style:
        raise gr.Error("Select a valid style!")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    original_resized, w, h = prepare_image(image)
    generator = torch.Generator(device=device).manual_seed(seed)

    progress(0.2, desc=f"Applying {style['name']} style...")
    output = pipe(
        prompt=build_prompt(style["prompt"]),
        image=[original_resized],
        height=h,
        width=w,
        num_inference_steps=num_steps,
        generator=generator,
    )
    result = output.images[0]

    progress(0.9, desc="Preparing before/after slider...")
    comparison = (original_resized, result)
    return comparison, seed, original_resized, result
89
+
90
@spaces.GPU(duration=120)
def generate_previews(progress=gr.Progress(track_tqdm=True)):
    """Generate one preview image per style from the bundled reference photo.

    Results are memoised in the module-level ``PREVIEW_CACHE``; a non-empty
    cache is returned as-is. If the reference image is missing, every style
    maps to ``None``. A style whose generation raises is logged and also
    mapped to ``None`` so one bad style does not abort the others.

    Args:
        progress: Gradio progress tracker.

    Returns:
        Dict mapping each style ``"id"`` to its preview image or ``None``.
    """
    global PREVIEW_CACHE
    if PREVIEW_CACHE:
        return PREVIEW_CACHE
    if not os.path.exists(REFERENCE_PATH):
        PREVIEW_CACHE = {s["id"]: None for s in STYLES}
        return PREVIEW_CACHE

    # Close the file handle once the pixel data has been converted.
    with Image.open(REFERENCE_PATH) as src:
        ref_img = src.convert("RGB")

    # Feed the pipeline the resized reference so the input matches the
    # requested width/height, as transform_image does. The original code
    # discarded the resized image and passed the full-size one.
    ref_resized, w, h = prepare_image(ref_img)

    total = len(STYLES)
    for i, style in enumerate(STYLES):
        progress(i / total, desc=f"Generating preview: {style['name']}")
        generator = torch.Generator(device=device).manual_seed(style.get("preview_seed", 42))
        try:
            res = pipe(
                prompt=build_prompt(style["prompt"]),
                image=[ref_resized],
                height=h,
                width=w,
                num_inference_steps=4,
                generator=generator,
            ).images[0]
            PREVIEW_CACHE[style["id"]] = res
        except Exception as e:
            # Deliberate best-effort: keep going and show a placeholder.
            print(f"Error generating preview for {style['name']}: {e}")
            PREVIEW_CACHE[style["id"]] = None
    return PREVIEW_CACHE
117
+
118
def load_previews_to_gallery():
    """Build the gallery items, one ``(image, caption)`` pair per style.

    Styles without a preview get a freshly created 512x512 dark placeholder.
    Captions are ``"<emoji> <name>"``.
    """
    previews = generate_previews()

    def _thumbnail(style):
        # Fall back to a new solid-colour placeholder when no preview exists.
        preview = previews.get(style["id"])
        if preview:
            return preview
        return Image.new("RGB", (512, 512), (20, 18, 30))

    return [(_thumbnail(style), f"{style['emoji']} {style['name']}") for style in STYLES]
129
+
130
def select_style_from_gallery(evt: gr.SelectData):
    """Return the ``"id"`` of the style at the selected gallery position.

    ``evt.index`` is used directly as an index into ``STYLES``.
    """
    return STYLES[evt.index]["id"]
133
+
134
+ css = """
135
+ @import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
136
+ .gradio-container { background: radial-gradient(circle at top left, #fff4e8 0%, #f7ddc0 32%, #d79a72 68%, #8f4f3d 100%) !important; max-width: 1120px !important; margin: 0 auto !important; font-family: 'Space Grotesk', sans-serif !important; }
137
+ .main-title h1 { text-align: center; color: #5a2416 !important; font-size: 2.7em !important; font-weight: 700 !important; letter-spacing: -1px; margin-bottom: 0 !important; }
138
+ .subtitle p { text-align: center; color: #6f3a2a !important; font-size: 1.05em !important; margin-top: 0 !important; }
139
+ .step-guide p { text-align: center; color: #6b3527 !important; font-size: 1.03em !important; margin: 6px 0 10px !important; }
140
+ .credit-line p { text-align: center; color: #7b3f2d !important; font-size: 0.98em !important; margin: 4px 0 20px !important; }
141
+ .credit-line a { color: #8f2d1f !important; text-decoration: none !important; font-weight: 700 !important; }
142
+ .credit-line a:hover { text-decoration: underline !important; }
143
+ #style-gallery { border: 1px solid #b96b4c !important; border-radius: 12px !important; background: #fff8f0 !important; }
144
+ #style-gallery .grid-wrap { gap: 8px !important; }
145
+ #style-gallery .thumbnail-item { border: 2px solid transparent !important; border-radius: 10px !important; transition: all 0.2s ease !important; }
146
+ #style-gallery .thumbnail-item:hover { border-color: #cf6e43aa !important; transform: translateY(-2px); }
147
+ #style-gallery .thumbnail-item.selected,
148
+ #style-gallery .thumbnail-item[aria-selected="true"] {
149
+ border-color: #bf5f34 !important;
150
+ box-shadow: 0 0 0 2px #cf6e4380, 0 10px 26px #8f4f3d40 !important;
151
+ transform: translateY(-2px) scale(1.02) !important;
152
+ }
153
+ #input-img { border: 2px dashed #c57d57 !important; border-radius: 12px !important; background: #fff8f0 !important; min-height: 320px; }
154
+ #output-slider { border: 1px solid #c57d57 !important; border-radius: 12px !important; background: #fff8f0 !important; overflow: hidden !important; }
155
+ #go-btn { background: linear-gradient(135deg, #8f2d1f, #d76a3f) !important; color: #fff8f0 !important; font-weight: 700 !important; font-size: 1.15em !important; min-height: 52px !important; border: none !important; border-radius: 12px !important; box-shadow: 0 4px 20px #8f2d1f33; transition: all 0.2s ease !important; }
156
+ #go-btn:hover { box-shadow: 0 6px 26px #8f2d1f4d; transform: translateY(-2px); }
157
+ #dl-btn { background: #fff3e6 !important; color: #8f2d1f !important; border: 1px solid #c57d57 !important; border-radius: 10px !important; }
158
+ .progress-bar { background-color: #d76a3f !important; }
159
+ .progress-bar-wrap { background-color: #f3d8bd !important; }
160
+ * { --neutral-50: #fff8f0 !important; --neutral-100: #fff3e6 !important; --neutral-200: #e7bb96 !important; }
161
+ .dark { --body-background-fill: #f7ddc0; }
162
+ footer { display: none !important; }
163
+ """
164
+
165
+ with gr.Blocks(title="Style My Portrait", css=css, theme=gr.themes.Base(
166
+ primary_hue=gr.themes.colors.orange,
167
+ secondary_hue=gr.themes.colors.red,
168
+ neutral_hue=gr.themes.colors.stone,
169
+ font=gr.themes.GoogleFont("Space Grotesk"),
170
+ ).set(
171
+ body_background_fill="#f7ddc0",
172
+ body_background_fill_dark="#f7ddc0",
173
+ block_background_fill="#fff8f0",
174
+ block_background_fill_dark="#fff8f0",
175
+ block_border_color="#c57d57",
176
+ block_border_color_dark="#c57d57",
177
+ block_label_text_color="#8f2d1f",
178
+ block_label_text_color_dark="#8f2d1f",
179
+ block_title_text_color="#8f2d1f",
180
+ block_title_text_color_dark="#8f2d1f",
181
+ body_text_color="#5a2416",
182
+ body_text_color_dark="#5a2416",
183
+ button_primary_background_fill="#d76a3f",
184
+ button_primary_background_fill_dark="#d76a3f",
185
+ button_primary_text_color="#fff8f0",
186
+ button_primary_text_color_dark="#fff8f0",
187
+ input_background_fill="#fff3e6",
188
+ input_background_fill_dark="#fff3e6",
189
+ input_border_color="#c57d57",
190
+ input_border_color_dark="#c57d57",
191
+ border_color_accent="#d76a3f",
192
+ border_color_accent_dark="#d76a3f",
193
+ border_color_primary="#c57d57",
194
+ border_color_primary_dark="#c57d57",
195
+ background_fill_secondary="#fff8f0",
196
+ background_fill_secondary_dark="#fff8f0",
197
+ background_fill_primary="#f7ddc0",
198
+ background_fill_primary_dark="#f7ddc0",
199
+ shadow_drop="none",
200
+ shadow_drop_lg="none",
201
+ slider_color="#d76a3f",
202
+ slider_color_dark="#d76a3f",
203
+ checkbox_background_color="#fff3e6",
204
+ checkbox_background_color_dark="#fff3e6",
205
+ checkbox_background_color_selected="#d76a3f",
206
+ checkbox_background_color_selected_dark="#d76a3f",
207
+ )) as demo:
208
+
209
+ gr.Markdown("# Style My Portrait", elem_classes="main-title")
210
+ gr.Markdown("Turn your portrait into polished visual styles with FLUX.2 Klein", elem_classes="subtitle")
211
+ gr.Markdown("**1)** Upload your photo • **2)** Pick a style from the gallery • **3)** Click transform", elem_classes="step-guide")
212
+ gr.Markdown(
213
+ "Built by [@artificialguybr](https://twitter.com/artificialguybr) • Explore more image editing and image generation prompts at [artificialguy.com](https://artificialguy.com) and [findgoodprompt.com](https://findgoodprompt.com)",
214
+ elem_classes="credit-line",
215
+ )
216
+
217
+ selected_style_id = gr.State(STYLES[0]["id"])
218
+ original_state = gr.State(None)
219
+ enhanced_state = gr.State(None)
220
+
221
+ with gr.Row(equal_height=True):
222
+ with gr.Column(scale=1):
223
+ input_image = gr.Image(label="📸 Upload & Preview", type="pil", elem_id="input-img")
224
+ with gr.Accordion("⚙️ Settings", open=False):
225
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
226
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
227
+ num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=20, step=1, value=4)
228
+ go_btn = gr.Button("🚀 Transform", elem_id="go-btn", variant="primary")
229
+
230
+ with gr.Column(scale=1):
231
+ gr.Markdown("### 🎨 Gallery", elem_classes="subtitle")
232
+ style_gallery = gr.Gallery(
233
+ label="Style Gallery",
234
+ show_label=False,
235
+ elem_id="style-gallery",
236
+ columns=[4], rows=[2],
237
+ height=360,
238
+ object_fit="cover",
239
+ allow_preview=False,
240
+ value=load_previews_to_gallery()
241
+ )
242
+
243
+ style_gallery.select(
244
+ fn=select_style_from_gallery,
245
+ inputs=[],
246
+ outputs=[selected_style_id]
247
+ )
248
+
249
+ gr.Markdown("### 🖼️ Result", elem_classes="subtitle")
250
+ with gr.Row():
251
+ output_image = gr.ImageSlider(label="Before / After", type="pil", elem_id="output-slider", slider_position=50)
252
+ with gr.Row():
253
+ dl_btn = gr.DownloadButton("📥 Download Image", elem_id="dl-btn", visible=False)
254
+
255
def on_generate(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
    """Run the transform and reveal the download button for the result.

    Args:
        image, style_id, seed, randomize_seed, num_steps, progress:
            Forwarded unchanged to ``transform_image``.

    Returns:
        ``(comparison, seed, original, enhanced, download_button_update)``,
        in the order of the click handler's outputs.
    """
    import tempfile

    comparison, seed, orig, enh = transform_image(image, style_id, seed, randomize_seed, num_steps, progress)

    # gr.DownloadButton takes a file path or URL as its value, not an
    # in-memory image, so write the result to a temporary PNG first.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        enh.save(tmp, format="PNG")
        download_path = tmp.name

    return comparison, seed, orig, enh, gr.update(visible=True, value=download_path)
258
+
259
+ go_btn.click(
260
+ fn=on_generate,
261
+ inputs=[input_image, selected_style_id, seed, randomize_seed, num_inference_steps],
262
+ outputs=[output_image, seed, original_state, enhanced_state, dl_btn],
263
+ )
264
+
265
+ input_image.change(
266
+ fn=lambda: (None, gr.update(visible=False), None, None),
267
+ inputs=[],
268
+ outputs=[output_image, dl_btn, original_state, enhanced_state],
269
+ )
270
+
271
+ demo.launch(ssr_mode=False)
flux2_klein_kv.patch ADDED
@@ -0,0 +1,1565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
2
+ index 546fbe57b..0be7b8166 100644
3
+ --- a/src/diffusers/__init__.py
4
+ +++ b/src/diffusers/__init__.py
5
+ @@ -510,6 +510,7 @@ else:
6
+ "EasyAnimateControlPipeline",
7
+ "EasyAnimateInpaintPipeline",
8
+ "EasyAnimatePipeline",
9
+ + "Flux2KleinKVPipeline",
10
+ "Flux2KleinPipeline",
11
+ "Flux2Pipeline",
12
+ "FluxControlImg2ImgPipeline",
13
+ @@ -1266,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
14
+ EasyAnimateControlPipeline,
15
+ EasyAnimateInpaintPipeline,
16
+ EasyAnimatePipeline,
17
+ + Flux2KleinKVPipeline,
18
+ Flux2KleinPipeline,
19
+ Flux2Pipeline,
20
+ FluxControlImg2ImgPipeline,
21
+ diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
22
+ index f77498c74..d2ecba583 100644
23
+ --- a/src/diffusers/models/transformers/transformer_flux2.py
24
+ +++ b/src/diffusers/models/transformers/transformer_flux2.py
25
+ @@ -40,6 +40,193 @@ from ..normalization import AdaLayerNormContinuous
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ +class Flux2KVLayerCache:
30
+ + """Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
31
+ +
32
+ + Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step.
33
+ + Tensor format: (batch_size, num_ref_tokens, num_heads, head_dim).
34
+ + """
35
+ +
36
+ + def __init__(self):
37
+ + self.k_ref: torch.Tensor | None = None
38
+ + self.v_ref: torch.Tensor | None = None
39
+ +
40
+ + def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
41
+ + """Store reference token K/V."""
42
+ + self.k_ref = k_ref
43
+ + self.v_ref = v_ref
44
+ +
45
+ + def get(self) -> tuple[torch.Tensor, torch.Tensor]:
46
+ + """Retrieve cached reference token K/V."""
47
+ + if self.k_ref is None:
48
+ + raise RuntimeError("KV cache has not been populated yet.")
49
+ + return self.k_ref, self.v_ref
50
+ +
51
+ + def clear(self):
52
+ + self.k_ref = None
53
+ + self.v_ref = None
54
+ +
55
+ +
56
+ +class Flux2KVCache:
57
+ + """Container for all layers' reference-token KV caches.
58
+ +
59
+ + Holds separate cache lists for double-stream and single-stream transformer blocks.
60
+ + """
61
+ +
62
+ + def __init__(self, num_double_layers: int, num_single_layers: int):
63
+ + self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
64
+ + self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
65
+ + self.num_ref_tokens: int = 0
66
+ +
67
+ + def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
68
+ + return self.double_block_caches[layer_idx]
69
+ +
70
+ + def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
71
+ + return self.single_block_caches[layer_idx]
72
+ +
73
+ + def clear(self):
74
+ + for cache in self.double_block_caches:
75
+ + cache.clear()
76
+ + for cache in self.single_block_caches:
77
+ + cache.clear()
78
+ + self.num_ref_tokens = 0
79
+ +
80
+ +
81
+ +def _flux2_kv_causal_attention(
82
+ + query: torch.Tensor,
83
+ + key: torch.Tensor,
84
+ + value: torch.Tensor,
85
+ + num_txt_tokens: int,
86
+ + num_ref_tokens: int,
87
+ + kv_cache: Flux2KVLayerCache | None = None,
88
+ + backend=None,
89
+ +) -> torch.Tensor:
90
+ + """Causal attention for KV caching where reference tokens only self-attend.
91
+ +
92
+ + All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
93
+ +
94
+ + Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens,
95
+ + ref tokens only attend to themselves.
96
+ + With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected between txt and img.
97
+ + """
98
+ + # No ref tokens and no cache — standard full attention
99
+ + if num_ref_tokens == 0 and kv_cache is None:
100
+ + return dispatch_attention_fn(query, key, value, backend=backend)
101
+ +
102
+ + if kv_cache is not None:
103
+ + # Cached mode: inject ref K/V between txt and img
104
+ + k_ref, v_ref = kv_cache.get()
105
+ +
106
+ + k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
107
+ + v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
108
+ +
109
+ + return dispatch_attention_fn(query, k_all, v_all, backend=backend)
110
+ +
111
+ + # Extract mode: ref tokens self-attend, txt+img attend to all
112
+ + ref_start = num_txt_tokens
113
+ + ref_end = num_txt_tokens + num_ref_tokens
114
+ +
115
+ + q_txt = query[:, :ref_start]
116
+ + q_ref = query[:, ref_start:ref_end]
117
+ + q_img = query[:, ref_end:]
118
+ +
119
+ + k_txt = key[:, :ref_start]
120
+ + k_ref = key[:, ref_start:ref_end]
121
+ + k_img = key[:, ref_end:]
122
+ +
123
+ + v_txt = value[:, :ref_start]
124
+ + v_ref = value[:, ref_start:ref_end]
125
+ + v_img = value[:, ref_end:]
126
+ +
127
+ + # txt+img attend to all tokens
128
+ + q_txt_img = torch.cat([q_txt, q_img], dim=1)
129
+ + k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
130
+ + v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
131
+ + attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
132
+ + attn_txt = attn_txt_img[:, :ref_start]
133
+ + attn_img = attn_txt_img[:, ref_start:]
134
+ +
135
+ + # ref tokens self-attend only
136
+ + attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
137
+ +
138
+ + return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
139
+ +
140
+ +
141
+ +def _blend_mod_params(
142
+ + img_params: tuple[torch.Tensor, ...],
143
+ + ref_params: tuple[torch.Tensor, ...],
144
+ + num_ref: int,
145
+ + seq_len: int,
146
+ +) -> tuple[torch.Tensor, ...]:
147
+ + """Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
148
+ + blended = []
149
+ + for im, rm in zip(img_params, ref_params):
150
+ + if im.ndim == 2:
151
+ + im = im.unsqueeze(1)
152
+ + rm = rm.unsqueeze(1)
153
+ + B = im.shape[0]
154
+ + blended.append(
155
+ + torch.cat(
156
+ + [rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
157
+ + dim=1,
158
+ + )
159
+ + )
160
+ + return tuple(blended)
161
+ +
162
+ +
163
+ +def _blend_double_block_mods(
164
+ + img_mod: torch.Tensor,
165
+ + ref_mod: torch.Tensor,
166
+ + num_ref: int,
167
+ + seq_len: int,
168
+ +) -> torch.Tensor:
169
+ + """Blend double-block image-stream modulations for a [ref, img] sequence layout.
170
+ +
171
+ + Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is
172
+ + compatible with `Flux2Modulation.split(mod, 2)`.
173
+ + """
174
+ + img_mods = Flux2Modulation.split(img_mod, 2)
175
+ + ref_mods = Flux2Modulation.split(ref_mod, 2)
176
+ +
177
+ + all_params = []
178
+ + for img_set, ref_set in zip(img_mods, ref_mods):
179
+ + blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
180
+ + all_params.extend(blended)
181
+ + return torch.cat(all_params, dim=-1)
182
+ +
183
+ +
184
+ +def _blend_single_block_mods(
185
+ + single_mod: torch.Tensor,
186
+ + ref_mod: torch.Tensor,
187
+ + num_txt: int,
188
+ + num_ref: int,
189
+ + seq_len: int,
190
+ +) -> torch.Tensor:
191
+ + """Blend single-block modulations for a [txt, ref, img] sequence layout.
192
+ +
193
+ + Takes raw modulation tensors and returns a blended raw tensor compatible with
194
+ + `Flux2Modulation.split(mod, 1)`.
195
+ + """
196
+ + img_params = Flux2Modulation.split(single_mod, 1)[0]
197
+ + ref_params = Flux2Modulation.split(ref_mod, 1)[0]
198
+ +
199
+ + blended = []
200
+ + for im, rm in zip(img_params, ref_params):
201
+ + if im.ndim == 2:
202
+ + im = im.unsqueeze(1)
203
+ + rm = rm.unsqueeze(1)
204
+ + B = im.shape[0]
205
+ + im_expanded = im.expand(B, seq_len, -1)
206
+ + rm_expanded = rm.expand(B, num_ref, -1)
207
+ + blended.append(
208
+ + torch.cat(
209
+ + [im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
210
+ + dim=1,
211
+ + )
212
+ + )
213
+ + return torch.cat(blended, dim=-1)
214
+ +
215
+ +
216
+ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
217
+ query = attn.to_q(hidden_states)
218
+ key = attn.to_k(hidden_states)
219
+ @@ -181,9 +368,105 @@ class Flux2AttnProcessor:
220
+ return hidden_states
221
+
222
+
223
+ +class Flux2KVAttnProcessor:
224
+ + """
225
+ + Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
226
+ +
227
+ + When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal
228
+ + attention is used (ref tokens self-attend only, txt+img attend to all).
229
+ + When `kv_cache_mode` is "cached", cached ref K/V are injected during attention.
230
+ + When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
231
+ + """
232
+ +
233
+ + _attention_backend = None
234
+ + _parallel_config = None
235
+ +
236
+ + def __init__(self):
237
+ + if not hasattr(F, "scaled_dot_product_attention"):
238
+ + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
239
+ +
240
+ + def __call__(
241
+ + self,
242
+ + attn: "Flux2Attention",
243
+ + hidden_states: torch.Tensor,
244
+ + encoder_hidden_states: torch.Tensor = None,
245
+ + attention_mask: torch.Tensor | None = None,
246
+ + image_rotary_emb: torch.Tensor | None = None,
247
+ + kv_cache: Flux2KVLayerCache | None = None,
248
+ + kv_cache_mode: str | None = None,
249
+ + num_ref_tokens: int = 0,
250
+ + ) -> torch.Tensor:
251
+ + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
252
+ + attn, hidden_states, encoder_hidden_states
253
+ + )
254
+ +
255
+ + query = query.unflatten(-1, (attn.heads, -1))
256
+ + key = key.unflatten(-1, (attn.heads, -1))
257
+ + value = value.unflatten(-1, (attn.heads, -1))
258
+ +
259
+ + query = attn.norm_q(query)
260
+ + key = attn.norm_k(key)
261
+ +
262
+ + if attn.added_kv_proj_dim is not None:
263
+ + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
264
+ + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
265
+ + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
266
+ +
267
+ + encoder_query = attn.norm_added_q(encoder_query)
268
+ + encoder_key = attn.norm_added_k(encoder_key)
269
+ +
270
+ + query = torch.cat([encoder_query, query], dim=1)
271
+ + key = torch.cat([encoder_key, key], dim=1)
272
+ + value = torch.cat([encoder_value, value], dim=1)
273
+ +
274
+ + if image_rotary_emb is not None:
275
+ + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
276
+ + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
277
+ +
278
+ + num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
279
+ +
280
+ + # Extract ref K/V from the combined sequence
281
+ + if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
282
+ + ref_start = num_txt_tokens
283
+ + ref_end = num_txt_tokens + num_ref_tokens
284
+ + kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
285
+ +
286
+ + # Dispatch attention
287
+ + if kv_cache_mode == "extract" and num_ref_tokens > 0:
288
+ + hidden_states = _flux2_kv_causal_attention(
289
+ + query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
290
+ + )
291
+ + elif kv_cache_mode == "cached" and kv_cache is not None:
292
+ + hidden_states = _flux2_kv_causal_attention(
293
+ + query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
294
+ + )
295
+ + else:
296
+ + hidden_states = dispatch_attention_fn(
297
+ + query, key, value, attn_mask=attention_mask,
298
+ + backend=self._attention_backend, parallel_config=self._parallel_config,
299
+ + )
300
+ +
301
+ + hidden_states = hidden_states.flatten(2, 3)
302
+ + hidden_states = hidden_states.to(query.dtype)
303
+ +
304
+ + if encoder_hidden_states is not None:
305
+ + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
306
+ + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
307
+ + )
308
+ + encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
309
+ +
310
+ + hidden_states = attn.to_out[0](hidden_states)
311
+ + hidden_states = attn.to_out[1](hidden_states)
312
+ +
313
+ + if encoder_hidden_states is not None:
314
+ + return hidden_states, encoder_hidden_states
315
+ + else:
316
+ + return hidden_states
317
+ +
318
+ +
319
+ class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
320
+ _default_processor_cls = Flux2AttnProcessor
321
+ - _available_processors = [Flux2AttnProcessor]
322
+ + _available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
323
+
324
+ def __init__(
325
+ self,
326
+ @@ -312,6 +595,86 @@ class Flux2ParallelSelfAttnProcessor:
327
+ return hidden_states
328
+
329
+
330
+ +class Flux2KVParallelSelfAttnProcessor:
331
+ + """
332
+ + Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
333
+ +
334
+ + When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used.
335
+ + When `kv_cache_mode` is "cached", cached ref K/V are injected during attention.
336
+ + When no KV args are provided, behaves identically to `Flux2ParallelSelfAttnProcessor`.
337
+ + """
338
+ +
339
+ + _attention_backend = None
340
+ + _parallel_config = None
341
+ +
342
+ + def __init__(self):
343
+ + if not hasattr(F, "scaled_dot_product_attention"):
344
+ + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
345
+ +
346
+ + def __call__(
347
+ + self,
348
+ + attn: "Flux2ParallelSelfAttention",
349
+ + hidden_states: torch.Tensor,
350
+ + attention_mask: torch.Tensor | None = None,
351
+ + image_rotary_emb: torch.Tensor | None = None,
352
+ + kv_cache: Flux2KVLayerCache | None = None,
353
+ + kv_cache_mode: str | None = None,
354
+ + num_txt_tokens: int = 0,
355
+ + num_ref_tokens: int = 0,
356
+ + ) -> torch.Tensor:
357
+ + # Parallel in (QKV + MLP in) projection
358
+ + hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
359
+ + qkv, mlp_hidden_states = torch.split(
360
+ + hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
361
+ + )
362
+ +
363
+ + query, key, value = qkv.chunk(3, dim=-1)
364
+ +
365
+ + query = query.unflatten(-1, (attn.heads, -1))
366
+ + key = key.unflatten(-1, (attn.heads, -1))
367
+ + value = value.unflatten(-1, (attn.heads, -1))
368
+ +
369
+ + query = attn.norm_q(query)
370
+ + key = attn.norm_k(key)
371
+ +
372
+ + if image_rotary_emb is not None:
373
+ + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
374
+ + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
375
+ +
376
+ + # Extract ref K/V from the combined sequence
377
+ + if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
378
+ + ref_start = num_txt_tokens
379
+ + ref_end = num_txt_tokens + num_ref_tokens
380
+ + kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
381
+ +
382
+ + # Dispatch attention
383
+ + if kv_cache_mode == "extract" and num_ref_tokens > 0:
384
+ + attn_output = _flux2_kv_causal_attention(
385
+ + query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
386
+ + )
387
+ + elif kv_cache_mode == "cached" and kv_cache is not None:
388
+ + attn_output = _flux2_kv_causal_attention(
389
+ + query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
390
+ + )
391
+ + else:
392
+ + attn_output = dispatch_attention_fn(
393
+ + query, key, value, attn_mask=attention_mask,
394
+ + backend=self._attention_backend, parallel_config=self._parallel_config,
395
+ + )
396
+ +
397
+ + attn_output = attn_output.flatten(2, 3)
398
+ + attn_output = attn_output.to(query.dtype)
399
+ +
400
+ + # Handle the feedforward (FF) logic
401
+ + mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
402
+ +
403
+ + # Concatenate and parallel output projection
404
+ + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
405
+ + hidden_states = attn.to_out(hidden_states)
406
+ +
407
+ + return hidden_states
408
+ +
409
+ +
410
+ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
411
+ """
412
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
413
+ @@ -322,7 +685,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
414
+ """
415
+
416
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
417
+ - _available_processors = [Flux2ParallelSelfAttnProcessor]
418
+ + _available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
419
+ # Does not support QKV fusion as the QKV projections are always fused
420
+ _supports_qkv_fusion = False
421
+
422
+ @@ -780,6 +1143,8 @@ class Flux2Transformer2DModel(
423
+
424
+ self.gradient_checkpointing = False
425
+
426
+ + _skip_keys = ["kv_cache"]
427
+ +
428
+ @apply_lora_scale("joint_attention_kwargs")
429
+ def forward(
430
+ self,
431
+ @@ -791,19 +1156,21 @@ class Flux2Transformer2DModel(
432
+ guidance: torch.Tensor = None,
433
+ joint_attention_kwargs: dict[str, Any] | None = None,
434
+ return_dict: bool = True,
435
+ + kv_cache: "Flux2KVCache | None" = None,
436
+ + kv_cache_mode: str | None = None,
437
+ + num_ref_tokens: int = 0,
438
+ + ref_fixed_timestep: float = 0.0,
439
+ ) -> torch.Tensor | Transformer2DModelOutput:
440
+ """
441
+ - The [`FluxTransformer2DModel`] forward method.
442
+ + The [`Flux2Transformer2DModel`] forward method.
443
+
444
+ Args:
445
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
446
+ Input `hidden_states`.
447
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
448
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
449
+ - timestep ( `torch.LongTensor`):
450
+ + timestep (`torch.LongTensor`):
451
+ Used to indicate denoising step.
452
+ - block_controlnet_hidden_states: (`list` of `torch.Tensor`):
453
+ - A list of tensors that if specified are added to the residuals of transformer blocks.
454
+ joint_attention_kwargs (`dict`, *optional*):
455
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
456
+ `self.processor` in
457
+ @@ -811,13 +1178,23 @@ class Flux2Transformer2DModel(
458
+ return_dict (`bool`, *optional*, defaults to `True`):
459
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
460
+ tuple.
461
+ + kv_cache (`Flux2KVCache`, *optional*):
462
+ + KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created
463
+ + and returned. When "cached", the provided cache is used to inject ref K/V during attention.
464
+ + kv_cache_mode (`str`, *optional*):
465
+ + One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V).
466
+ + When `None`, standard forward pass without KV caching.
467
+ + num_ref_tokens (`int`, defaults to `0`):
468
+ + Number of reference image tokens prepended to `hidden_states` (only used when
469
+ + `kv_cache_mode="extract"`).
470
+ + ref_fixed_timestep (`float`, defaults to `0.0`):
471
+ + Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
472
+
473
+ Returns:
474
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
475
+ - `tuple` where the first element is the sample tensor.
476
+ + `tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
477
+ + populated `Flux2KVCache`.
478
+ """
479
+ - # 0. Handle input arguments
480
+ -
481
+ num_txt_tokens = encoder_hidden_states.shape[1]
482
+
483
+ # 1. Calculate timestep embedding and modulation parameters
484
+ @@ -832,13 +1209,33 @@ class Flux2Transformer2DModel(
485
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
486
+ single_stream_mod = self.single_stream_modulation(temb)
487
+
488
+ + # KV extract mode: create cache and blend modulations for ref tokens
489
+ + if kv_cache_mode == "extract" and num_ref_tokens > 0:
490
+ + num_img_tokens = hidden_states.shape[1] # includes ref tokens
491
+ +
492
+ + kv_cache = Flux2KVCache(
493
+ + num_double_layers=len(self.transformer_blocks),
494
+ + num_single_layers=len(self.single_transformer_blocks),
495
+ + )
496
+ + kv_cache.num_ref_tokens = num_ref_tokens
497
+ +
498
+ + # Ref tokens use a fixed timestep for modulation
499
+ + ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
500
+ + ref_temb = self.time_guidance_embed(ref_timestep, guidance)
501
+ +
502
+ + ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
503
+ + ref_single_mod = self.single_stream_modulation(ref_temb)
504
+ +
505
+ + # Blend double block img modulation: [ref_mod, img_mod]
506
+ + double_stream_mod_img = _blend_double_block_mods(
507
+ + double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
508
+ + )
509
+ +
510
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
511
+ hidden_states = self.x_embedder(hidden_states)
512
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
513
+
514
+ # 3. Calculate RoPE embeddings from image and text tokens
515
+ - # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
516
+ - # text prompts of differents lengths. Is this a use case we want to support?
517
+ if img_ids.ndim == 3:
518
+ img_ids = img_ids[0]
519
+ if txt_ids.ndim == 3:
520
+ @@ -851,8 +1248,29 @@ class Flux2Transformer2DModel(
521
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
522
+ )
523
+
524
+ - # 4. Double Stream Transformer Blocks
525
+ + # 4. Build joint_attention_kwargs with KV cache info
526
+ + if kv_cache_mode == "extract":
527
+ + kv_attn_kwargs = {
528
+ + **(joint_attention_kwargs or {}),
529
+ + "kv_cache": None,
530
+ + "kv_cache_mode": "extract",
531
+ + "num_ref_tokens": num_ref_tokens,
532
+ + }
533
+ + elif kv_cache_mode == "cached" and kv_cache is not None:
534
+ + kv_attn_kwargs = {
535
+ + **(joint_attention_kwargs or {}),
536
+ + "kv_cache": None,
537
+ + "kv_cache_mode": "cached",
538
+ + "num_ref_tokens": kv_cache.num_ref_tokens,
539
+ + }
540
+ + else:
541
+ + kv_attn_kwargs = joint_attention_kwargs
542
+ +
543
+ + # 5. Double Stream Transformer Blocks
544
+ for index_block, block in enumerate(self.transformer_blocks):
545
+ + if kv_cache_mode is not None and kv_cache is not None:
546
+ + kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
547
+ +
548
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
549
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
550
+ block,
551
+ @@ -861,7 +1279,7 @@ class Flux2Transformer2DModel(
552
+ double_stream_mod_img,
553
+ double_stream_mod_txt,
554
+ concat_rotary_emb,
555
+ - joint_attention_kwargs,
556
+ + kv_attn_kwargs,
557
+ )
558
+ else:
559
+ encoder_hidden_states, hidden_states = block(
560
+ @@ -870,13 +1288,30 @@ class Flux2Transformer2DModel(
561
+ temb_mod_img=double_stream_mod_img,
562
+ temb_mod_txt=double_stream_mod_txt,
563
+ image_rotary_emb=concat_rotary_emb,
564
+ - joint_attention_kwargs=joint_attention_kwargs,
565
+ + joint_attention_kwargs=kv_attn_kwargs,
566
+ )
567
+ +
568
+ # Concatenate text and image streams for single-block inference
569
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
570
+
571
+ - # 5. Single Stream Transformer Blocks
572
+ + # Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
573
+ + if kv_cache_mode == "extract" and num_ref_tokens > 0:
574
+ + total_single_len = hidden_states.shape[1]
575
+ + single_stream_mod = _blend_single_block_mods(
576
+ + single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
577
+ + )
578
+ +
579
+ + # Build single-block KV kwargs (single blocks need num_txt_tokens)
580
+ + if kv_cache_mode is not None:
581
+ + kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
582
+ + else:
583
+ + kv_attn_kwargs_single = kv_attn_kwargs
584
+ +
585
+ + # 6. Single Stream Transformer Blocks
586
+ for index_block, block in enumerate(self.single_transformer_blocks):
587
+ + if kv_cache_mode is not None and kv_cache is not None:
588
+ + kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
589
+ +
590
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
591
+ hidden_states = self._gradient_checkpointing_func(
592
+ block,
593
+ @@ -884,7 +1319,7 @@ class Flux2Transformer2DModel(
594
+ None,
595
+ single_stream_mod,
596
+ concat_rotary_emb,
597
+ - joint_attention_kwargs,
598
+ + kv_attn_kwargs_single,
599
+ )
600
+ else:
601
+ hidden_states = block(
602
+ @@ -892,15 +1327,24 @@ class Flux2Transformer2DModel(
603
+ encoder_hidden_states=None,
604
+ temb_mod=single_stream_mod,
605
+ image_rotary_emb=concat_rotary_emb,
606
+ - joint_attention_kwargs=joint_attention_kwargs,
607
+ + joint_attention_kwargs=kv_attn_kwargs_single,
608
+ )
609
+ - # Remove text tokens from concatenated stream
610
+ - hidden_states = hidden_states[:, num_txt_tokens:, ...]
611
+
612
+ - # 6. Output layers
613
+ + # Remove text tokens (and ref tokens in extract mode) from concatenated stream
614
+ + if kv_cache_mode == "extract" and num_ref_tokens > 0:
615
+ + hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
616
+ + else:
617
+ + hidden_states = hidden_states[:, num_txt_tokens:, ...]
618
+ +
619
+ + # 7. Output layers
620
+ hidden_states = self.norm_out(hidden_states, temb)
621
+ output = self.proj_out(hidden_states)
622
+
623
+ + if kv_cache_mode == "extract":
624
+ + if not return_dict:
625
+ + return (output,), kv_cache
626
+ + return Transformer2DModelOutput(sample=output), kv_cache
627
+ +
628
+ if not return_dict:
629
+ return (output,)
630
+
631
+ diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
632
+ index 800703533..b9596f4b7 100644
633
+ --- a/src/diffusers/pipelines/__init__.py
634
+ +++ b/src/diffusers/pipelines/__init__.py
635
+ @@ -129,7 +129,7 @@ else:
636
+ ]
637
+ _import_structure["bria"] = ["BriaPipeline"]
638
+ _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
639
+ - _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
640
+ + _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
641
+ _import_structure["flux"] = [
642
+ "FluxControlPipeline",
643
+ "FluxControlInpaintPipeline",
644
+ @@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
645
+ FluxPriorReduxPipeline,
646
+ ReduxImageEncoder,
647
+ )
648
+ - from .flux2 import Flux2KleinPipeline, Flux2Pipeline
649
+ + from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
650
+ from .glm_image import GlmImagePipeline
651
+ from .helios import HeliosPipeline, HeliosPyramidPipeline
652
+ from .hidream_image import HiDreamImagePipeline
653
+ diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py
654
+ index f6e1d5206..52a8f464b 100644
655
+ --- a/src/diffusers/pipelines/flux2/__init__.py
656
+ +++ b/src/diffusers/pipelines/flux2/__init__.py
657
+ @@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
658
+ else:
659
+ _import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
660
+ _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
661
+ + _import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
662
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
663
+ try:
664
+ if not (is_transformers_available() and is_torch_available()):
665
+ @@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
666
+ else:
667
+ from .pipeline_flux2 import Flux2Pipeline
668
+ from .pipeline_flux2_klein import Flux2KleinPipeline
669
+ + from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
670
+ else:
671
+ import sys
672
+
673
+ diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
674
+ new file mode 100644
675
+ index 000000000..cac29f621
676
+ --- /dev/null
677
+ +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
678
+ @@ -0,0 +1,887 @@
679
+ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
680
+ +#
681
+ +# Licensed under the Apache License, Version 2.0 (the "License");
682
+ +# you may not use this file except in compliance with the License.
683
+ +# You may obtain a copy of the License at
684
+ +#
685
+ +# http://www.apache.org/licenses/LICENSE-2.0
686
+ +#
687
+ +# Unless required by applicable law or agreed to in writing, software
688
+ +# distributed under the License is distributed on an "AS IS" BASIS,
689
+ +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
690
+ +# See the License for the specific language governing permissions and
691
+ +# limitations under the License.
692
+ +
693
+ +import inspect
694
+ +from typing import Any, Callable
695
+ +
696
+ +import numpy as np
697
+ +import PIL
698
+ +import torch
699
+ +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
700
+ +
701
+ +from ...loaders import Flux2LoraLoaderMixin
702
+ +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
703
+ +from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
704
+ +from ...schedulers import FlowMatchEulerDiscreteScheduler
705
+ +from ...utils import is_torch_xla_available, logging, replace_example_docstring
706
+ +from ...utils.torch_utils import randn_tensor
707
+ +from ..pipeline_utils import DiffusionPipeline
708
+ +from .image_processor import Flux2ImageProcessor
709
+ +from .pipeline_output import Flux2PipelineOutput
710
+ +
711
+ +
712
+ +if is_torch_xla_available():
713
+ + import torch_xla.core.xla_model as xm
714
+ +
715
+ + XLA_AVAILABLE = True
716
+ +else:
717
+ + XLA_AVAILABLE = False
718
+ +
719
+ +
720
+ +logger = logging.get_logger(__name__) # pylint: disable=invalid-name
721
+ +
722
+ +EXAMPLE_DOC_STRING = """
723
+ + Examples:
724
+ + ```py
725
+ + >>> import torch
726
+ + >>> from PIL import Image
727
+ + >>> from diffusers import Flux2KleinKVPipeline
728
+ +
729
+ + >>> pipe = Flux2KleinKVPipeline.from_pretrained(
730
+ + ... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
731
+ + ... )
732
+ + >>> pipe.to("cuda")
733
+ + >>> ref_image = Image.open("reference.png")
734
+ + >>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
735
+ + >>> image.save("flux2_kv_output.png")
736
+ + ```
737
+ +"""
738
+ +
739
+ +
740
+ +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
741
+ +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
742
+ + a1, b1 = 8.73809524e-05, 1.89833333
743
+ + a2, b2 = 0.00016927, 0.45666666
744
+ +
745
+ + if image_seq_len > 4300:
746
+ + mu = a2 * image_seq_len + b2
747
+ + return float(mu)
748
+ +
749
+ + m_200 = a2 * image_seq_len + b2
750
+ + m_10 = a1 * image_seq_len + b1
751
+ +
752
+ + a = (m_200 - m_10) / 190.0
753
+ + b = m_200 - 200.0 * a
754
+ + mu = a * num_steps + b
755
+ +
756
+ + return float(mu)
757
+ +
758
+ +
759
+ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
760
+ +def retrieve_timesteps(
761
+ + scheduler,
762
+ + num_inference_steps: int | None = None,
763
+ + device: str | torch.device | None = None,
764
+ + timesteps: list[int] | None = None,
765
+ + sigmas: list[float] | None = None,
766
+ + **kwargs,
767
+ +):
768
+ + r"""
769
+ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
770
+ + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
771
+ +
772
+ + Args:
773
+ + scheduler (`SchedulerMixin`):
774
+ + The scheduler to get timesteps from.
775
+ + num_inference_steps (`int`):
776
+ + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
777
+ + must be `None`.
778
+ + device (`str` or `torch.device`, *optional*):
779
+ + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
780
+ + timesteps (`list[int]`, *optional*):
781
+ + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
782
+ + `num_inference_steps` and `sigmas` must be `None`.
783
+ + sigmas (`list[float]`, *optional*):
784
+ + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
785
+ + `num_inference_steps` and `timesteps` must be `None`.
786
+ +
787
+ + Returns:
788
+ + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
789
+ + second element is the number of inference steps.
790
+ + """
791
+ + if timesteps is not None and sigmas is not None:
792
+ + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
793
+ + if timesteps is not None:
794
+ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
795
+ + if not accepts_timesteps:
796
+ + raise ValueError(
797
+ + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
798
+ + f" timestep schedules. Please check whether you are using the correct scheduler."
799
+ + )
800
+ + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
801
+ + timesteps = scheduler.timesteps
802
+ + num_inference_steps = len(timesteps)
803
+ + elif sigmas is not None:
804
+ + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
805
+ + if not accept_sigmas:
806
+ + raise ValueError(
807
+ + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
808
+ + f" sigmas schedules. Please check whether you are using the correct scheduler."
809
+ + )
810
+ + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
811
+ + timesteps = scheduler.timesteps
812
+ + num_inference_steps = len(timesteps)
813
+ + else:
814
+ + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
815
+ + timesteps = scheduler.timesteps
816
+ + return timesteps, num_inference_steps
817
+ +
818
+ +
819
+ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
820
+ +def retrieve_latents(
821
+ + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
822
+ +):
823
+ + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
824
+ + return encoder_output.latent_dist.sample(generator)
825
+ + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
826
+ + return encoder_output.latent_dist.mode()
827
+ + elif hasattr(encoder_output, "latents"):
828
+ + return encoder_output.latents
829
+ + else:
830
+ + raise AttributeError("Could not access latents of provided encoder_output")
831
+ +
832
+ +
833
+ +class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
834
+ + r"""
835
+ + The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
836
+ +
837
+ + On the first denoising step, reference image tokens are included in the forward pass and their attention
838
+ + K/V projections are cached. On subsequent steps, the cached K/V are reused without recomputing,
839
+ + providing faster inference when using reference images.
840
+ +
841
+ + Reference:
842
+ + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
843
+ +
844
+ + Args:
845
+ + transformer ([`Flux2Transformer2DModel`]):
846
+ + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
847
+ + scheduler ([`FlowMatchEulerDiscreteScheduler`]):
848
+ + A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
849
+ + vae ([`AutoencoderKLFlux2`]):
850
+ + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
851
+ + text_encoder ([`Qwen3ForCausalLM`]):
852
+ + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
853
+ + tokenizer (`Qwen2TokenizerFast`):
854
+ + Tokenizer of class
855
+ + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
856
+ + """
857
+ +
858
+ + model_cpu_offload_seq = "text_encoder->transformer->vae"
859
+ + _callback_tensor_inputs = ["latents", "prompt_embeds"]
860
+ +
861
+ + def __init__(
862
+ + self,
863
+ + scheduler: FlowMatchEulerDiscreteScheduler,
864
+ + vae: AutoencoderKLFlux2,
865
+ + text_encoder: Qwen3ForCausalLM,
866
+ + tokenizer: Qwen2TokenizerFast,
867
+ + transformer: Flux2Transformer2DModel,
868
+ + is_distilled: bool = True,
869
+ + ):
870
+ + super().__init__()
871
+ +
872
+ + self.register_modules(
873
+ + vae=vae,
874
+ + text_encoder=text_encoder,
875
+ + tokenizer=tokenizer,
876
+ + scheduler=scheduler,
877
+ + transformer=transformer,
878
+ + )
879
+ +
880
+ + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
881
+ + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
882
+ + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
883
+ + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
884
+ + self.tokenizer_max_length = 512
885
+ + self.default_sample_size = 128
886
+ +
887
+ + # Set KV-cache-aware attention processors
888
+ + self._set_kv_attn_processors()
889
+ +
890
+ + @staticmethod
891
+ + def _get_qwen3_prompt_embeds(
892
+ + text_encoder: Qwen3ForCausalLM,
893
+ + tokenizer: Qwen2TokenizerFast,
894
+ + prompt: str | list[str],
895
+ + dtype: torch.dtype | None = None,
896
+ + device: torch.device | None = None,
897
+ + max_sequence_length: int = 512,
898
+ + hidden_states_layers: list[int] = (9, 18, 27),
899
+ + ):
900
+ + dtype = text_encoder.dtype if dtype is None else dtype
901
+ + device = text_encoder.device if device is None else device
902
+ +
903
+ + prompt = [prompt] if isinstance(prompt, str) else prompt
904
+ +
905
+ + all_input_ids = []
906
+ + all_attention_masks = []
907
+ +
908
+ + for single_prompt in prompt:
909
+ + messages = [{"role": "user", "content": single_prompt}]
910
+ + text = tokenizer.apply_chat_template(
911
+ + messages,
912
+ + tokenize=False,
913
+ + add_generation_prompt=True,
914
+ + enable_thinking=False,
915
+ + )
916
+ + inputs = tokenizer(
917
+ + text,
918
+ + return_tensors="pt",
919
+ + padding="max_length",
920
+ + truncation=True,
921
+ + max_length=max_sequence_length,
922
+ + )
923
+ +
924
+ + all_input_ids.append(inputs["input_ids"])
925
+ + all_attention_masks.append(inputs["attention_mask"])
926
+ +
927
+ + input_ids = torch.cat(all_input_ids, dim=0).to(device)
928
+ + attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
929
+ +
930
+ + # Forward pass through the model
931
+ + output = text_encoder(
932
+ + input_ids=input_ids,
933
+ + attention_mask=attention_mask,
934
+ + output_hidden_states=True,
935
+ + use_cache=False,
936
+ + )
937
+ +
938
+ + # Only use outputs from intermediate layers and stack them
939
+ + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
940
+ + out = out.to(dtype=dtype, device=device)
941
+ +
942
+ + batch_size, num_channels, seq_len, hidden_dim = out.shape
943
+ + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
944
+ +
945
+ + return prompt_embeds
946
+ +
947
+ + @staticmethod
948
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
949
+ + def _prepare_text_ids(
950
+ + x: torch.Tensor, # (B, L, D) or (L, D)
951
+ + t_coord: torch.Tensor | None = None,
952
+ + ):
953
+ + B, L, _ = x.shape
954
+ + out_ids = []
955
+ +
956
+ + for i in range(B):
957
+ + t = torch.arange(1) if t_coord is None else t_coord[i]
958
+ + h = torch.arange(1)
959
+ + w = torch.arange(1)
960
+ + l = torch.arange(L)
961
+ +
962
+ + coords = torch.cartesian_prod(t, h, w, l)
963
+ + out_ids.append(coords)
964
+ +
965
+ + return torch.stack(out_ids)
966
+ +
967
+ + @staticmethod
968
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
969
+ + def _prepare_latent_ids(
970
+ + latents: torch.Tensor, # (B, C, H, W)
971
+ + ):
972
+ + r"""
973
+ + Generates 4D position coordinates (T, H, W, L) for latent tensors.
974
+ +
975
+ + Args:
976
+ + latents (torch.Tensor):
977
+ + Latent tensor of shape (B, C, H, W)
978
+ +
979
+ + Returns:
980
+ + torch.Tensor:
981
+ + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
982
+ + H=[0..H-1], W=[0..W-1], L=0
983
+ + """
984
+ +
985
+ + batch_size, _, height, width = latents.shape
986
+ +
987
+ + t = torch.arange(1) # [0] - time dimension
988
+ + h = torch.arange(height)
989
+ + w = torch.arange(width)
990
+ + l = torch.arange(1) # [0] - layer dimension
991
+ +
992
+ + # Create position IDs: (H*W, 4)
993
+ + latent_ids = torch.cartesian_prod(t, h, w, l)
994
+ +
995
+ + # Expand to batch: (B, H*W, 4)
996
+ + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
997
+ +
998
+ + return latent_ids
999
+ +
1000
+ + @staticmethod
1001
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
1002
+ + def _prepare_image_ids(
1003
+ + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
1004
+ + scale: int = 10,
1005
+ + ):
1006
+ + r"""
1007
+ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
1008
+ +
1009
+ + This function creates a unique coordinate for every pixel/patch across all input latent with different
1010
+ + dimensions.
1011
+ +
1012
+ + Args:
1013
+ + image_latents (list[torch.Tensor]):
1014
+ + A list of image latent feature tensors, typically of shape (C, H, W).
1015
+ + scale (int, optional):
1016
+ + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
1017
+ + latent is: 'scale + scale * i'. Defaults to 10.
1018
+ +
1019
+ + Returns:
1020
+ + torch.Tensor:
1021
+ + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
1022
+ + input latents.
1023
+ +
1024
+ + Coordinate Components (Dimension 4):
1025
+ + - T (Time): The unique index indicating which latent image the coordinate belongs to.
1026
+ + - H (Height): The row index within that latent image.
1027
+ + - W (Width): The column index within that latent image.
1028
+ + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
1029
+ + """
1030
+ +
1031
+ + if not isinstance(image_latents, list):
1032
+ + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
1033
+ +
1034
+ + # create time offset for each reference image
1035
+ + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
1036
+ + t_coords = [t.view(-1) for t in t_coords]
1037
+ +
1038
+ + image_latent_ids = []
1039
+ + for x, t in zip(image_latents, t_coords):
1040
+ + x = x.squeeze(0)
1041
+ + _, height, width = x.shape
1042
+ +
1043
+ + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
1044
+ + image_latent_ids.append(x_ids)
1045
+ +
1046
+ + image_latent_ids = torch.cat(image_latent_ids, dim=0)
1047
+ + image_latent_ids = image_latent_ids.unsqueeze(0)
1048
+ +
1049
+ + return image_latent_ids
1050
+ +
1051
+ + @staticmethod
1052
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
1053
+ + def _patchify_latents(latents):
1054
+ + batch_size, num_channels_latents, height, width = latents.shape
1055
+ + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
1056
+ + latents = latents.permute(0, 1, 3, 5, 2, 4)
1057
+ + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
1058
+ + return latents
1059
+ +
1060
+ + @staticmethod
1061
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
1062
+ + def _unpatchify_latents(latents):
1063
+ + batch_size, num_channels_latents, height, width = latents.shape
1064
+ + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
1065
+ + latents = latents.permute(0, 1, 4, 2, 5, 3)
1066
+ + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
1067
+ + return latents
1068
+ +
1069
+ + @staticmethod
1070
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
1071
+ + def _pack_latents(latents):
1072
+ + """
1073
+ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
1074
+ + """
1075
+ +
1076
+ + batch_size, num_channels, height, width = latents.shape
1077
+ + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
1078
+ +
1079
+ + return latents
1080
+ +
1081
+ + @staticmethod
1082
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
1083
+ + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
1084
+ + """
1085
+ + using position ids to scatter tokens into place
1086
+ + """
1087
+ + x_list = []
1088
+ + for data, pos in zip(x, x_ids):
1089
+ + _, ch = data.shape # noqa: F841
1090
+ + h_ids = pos[:, 1].to(torch.int64)
1091
+ + w_ids = pos[:, 2].to(torch.int64)
1092
+ +
1093
+ + h = torch.max(h_ids) + 1
1094
+ + w = torch.max(w_ids) + 1
1095
+ +
1096
+ + flat_ids = h_ids * w + w_ids
1097
+ +
1098
+ + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
1099
+ + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
1100
+ +
1101
+ + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
1102
+ +
1103
+ + out = out.view(h, w, ch).permute(2, 0, 1)
1104
+ + x_list.append(out)
1105
+ +
1106
+ + return torch.stack(x_list, dim=0)
1107
+ +
1108
def _set_kv_attn_processors(self):
    """Install KV-cache-aware attention processors on every transformer block.

    Double-stream blocks (`transformer_blocks`) get a `Flux2KVAttnProcessor`; single-stream blocks
    (`single_transformer_blocks`) get a `Flux2KVParallelSelfAttnProcessor`. A fresh processor instance is
    created for each block.
    """
    transformer = self.transformer
    assignments = (
        (transformer.transformer_blocks, Flux2KVAttnProcessor),
        (transformer.single_transformer_blocks, Flux2KVParallelSelfAttnProcessor),
    )
    for blocks, processor_cls in assignments:
        for blk in blocks:
            blk.attn.set_processor(processor_cls())
1114
+ +
1115
+ + def encode_prompt(
1116
+ + self,
1117
+ + prompt: str | list[str],
1118
+ + device: torch.device | None = None,
1119
+ + num_images_per_prompt: int = 1,
1120
+ + prompt_embeds: torch.Tensor | None = None,
1121
+ + max_sequence_length: int = 512,
1122
+ + text_encoder_out_layers: tuple[int] = (9, 18, 27),
1123
+ + ):
1124
+ + device = device or self._execution_device
1125
+ +
1126
+ + if prompt is None:
1127
+ + prompt = ""
1128
+ +
1129
+ + prompt = [prompt] if isinstance(prompt, str) else prompt
1130
+ +
1131
+ + if prompt_embeds is None:
1132
+ + prompt_embeds = self._get_qwen3_prompt_embeds(
1133
+ + text_encoder=self.text_encoder,
1134
+ + tokenizer=self.tokenizer,
1135
+ + prompt=prompt,
1136
+ + device=device,
1137
+ + max_sequence_length=max_sequence_length,
1138
+ + hidden_states_layers=text_encoder_out_layers,
1139
+ + )
1140
+ +
1141
+ + batch_size, seq_len, _ = prompt_embeds.shape
1142
+ + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1143
+ + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
1144
+ +
1145
+ + text_ids = self._prepare_text_ids(prompt_embeds)
1146
+ + text_ids = text_ids.to(device)
1147
+ + return prompt_embeds, text_ids
1148
+ +
1149
# Adapted from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode an image batch into normalized, patchified VAE latents.

    Args:
        image: A 4-dimensional image tensor.
        generator: Forwarded to `retrieve_latents`.

    Returns:
        The patchified latents, shifted and scaled by the VAE's batch-norm running statistics.

    Raises:
        ValueError: If `image` is not 4-dimensional.
    """
    if image.ndim != 4:
        raise ValueError(f"Expected image dims 4, got {image.ndim}.")

    image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
    image_latents = self._patchify_latents(image_latents)

    # Both statistics are moved to the latents' device and dtype, mirroring the de-normalization done
    # before decoding in `__call__`. Leaving the std on the buffer's own device/dtype could raise a device
    # mismatch or silently promote the result to a different dtype.
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
        image_latents.device, image_latents.dtype
    )
    image_latents = (image_latents - latents_bn_mean) / latents_bn_std

    return image_latents
1162
+ +
1163
+ + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
1164
+ + def prepare_latents(
1165
+ + self,
1166
+ + batch_size,
1167
+ + num_latents_channels,
1168
+ + height,
1169
+ + width,
1170
+ + dtype,
1171
+ + device,
1172
+ + generator: torch.Generator,
1173
+ + latents: torch.Tensor | None = None,
1174
+ + ):
1175
+ + # VAE applies 8x compression on images but we must also account for packing which requires
1176
+ + # latent height and width to be divisible by 2.
1177
+ + height = 2 * (int(height) // (self.vae_scale_factor * 2))
1178
+ + width = 2 * (int(width) // (self.vae_scale_factor * 2))
1179
+ +
1180
+ + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
1181
+ + if isinstance(generator, list) and len(generator) != batch_size:
1182
+ + raise ValueError(
1183
+ + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1184
+ + f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1185
+ + )
1186
+ + if latents is None:
1187
+ + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1188
+ + else:
1189
+ + latents = latents.to(device=device, dtype=dtype)
1190
+ +
1191
+ + latent_ids = self._prepare_latent_ids(latents)
1192
+ + latent_ids = latent_ids.to(device)
1193
+ +
1194
+ + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
1195
+ + return latents, latent_ids
1196
+ +
1197
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
def prepare_image_latents(
    self,
    images: list[torch.Tensor],
    batch_size,
    generator: torch.Generator,
    device,
    dtype,
):
    """
    Encode reference images and concatenate their tokens into one conditioning sequence.

    Each image is moved to `device`/`dtype`, encoded with `_encode_vae_image`, packed into tokens, and
    all token sequences are concatenated along the sequence axis. The result and the matching position
    ids are then repeated `batch_size` times along the batch axis.

    Returns:
        `(image_latents, image_latent_ids)`.
    """
    encoded = [
        self._encode_vae_image(image=img.to(device=device, dtype=dtype), generator=generator)
        for img in images
    ]  # e.g. each (1, 128, 32, 32)

    image_latent_ids = self._prepare_image_ids(encoded)

    # Pack every latent to (tokens, channels), dropping its leading batch dim of 1,
    # e.g. (1, 128, 32, 32) -> (1, 1024, 128) -> (1024, 128).
    token_chunks = [self._pack_latents(latent).squeeze(0) for latent in encoded]

    # Join all reference tokens along the sequence axis and restore a batch dim: (1, N*1024, 128).
    image_latents = torch.cat(token_chunks, dim=0).unsqueeze(0)

    image_latents = image_latents.repeat(batch_size, 1, 1)
    image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1).to(device)

    return image_latents, image_latent_ids
1231
+ +
1232
def check_inputs(
    self,
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of `__call__`.

    A `height`/`width` that is not a multiple of `vae_scale_factor * 2` only triggers a warning. Unknown
    callback tensor names, passing both or neither of `prompt`/`prompt_embeds`, and a `prompt` that is
    neither `str` nor `list` raise `ValueError`.
    """
    if any(dim is not None and dim % (self.vae_scale_factor * 2) != 0 for dim in (height, width)):
        logger.warning(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
1268
+ +
1269
@property
def attention_kwargs(self):
    """The `attention_kwargs` value stored by the most recent `__call__`."""
    return self._attention_kwargs

@property
def num_timesteps(self):
    """Number of scheduler timesteps, as recorded by `__call__` after the schedule is built."""
    return self._num_timesteps

@property
def current_timestep(self):
    """Timestep of the denoising step in progress; `__call__` resets it to `None` before and after the loop."""
    return self._current_timestep

@property
def interrupt(self):
    """Flag polled at the top of every denoising step; `__call__` resets it to `False` on entry."""
    return self._interrupt
1284
+ +
1285
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
    prompt: str | list[str] = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 4,
    sigmas: list[float] | None = None,
    num_images_per_prompt: int = 1,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: torch.Tensor | None = None,
    prompt_embeds: torch.Tensor | None = None,
    output_type: str = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None] | None = None,
    callback_on_step_end_tensor_inputs: list[str] = ["latents"],
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
            Reference image(s) for conditioning. On the first denoising step, reference tokens are included
            in the forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are
            reused without recomputing. Images larger than 1024 * 1024 pixels are resized to that area, and
            every image is cropped to a multiple of `vae_scale_factor * 2`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. Exactly one of `prompt` and
            `prompt_embeds` must be given.
        height (`int`, *optional*):
            The height in pixels of the generated image. Defaults to the processed height of the first
            reference image, or to `default_sample_size * vae_scale_factor` when no image is given.
        width (`int`, *optional*):
            The width in pixels of the generated image. Defaults to the processed width of the first
            reference image, or to `default_sample_size * vae_scale_factor` when no image is given.
        num_inference_steps (`int`, *optional*, defaults to 4):
            The number of denoising steps.
        sigmas (`List[float]`, *optional*):
            Custom sigmas for the denoising schedule. Ignored when the scheduler config sets
            `use_flow_sigmas`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic generation.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format. `"latent"` returns the un-decoded latents; any other value is forwarded to the
            image processor's `postprocess` (e.g. `"pil"` or `"np"`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `Flux2PipelineOutput` or a plain tuple.
        attention_kwargs (`dict`, *optional*):
            Extra kwargs passed to attention processors.
        callback_on_step_end (`Callable`, *optional*):
            Callback function called at the end of each denoising step as
            `callback_on_step_end(pipeline, step_index, timestep, callback_kwargs)`. Its return value must
            support `.pop(key, default)`; the keys `"latents"` and `"prompt_embeds"` override the
            corresponding loop variables.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Tensor inputs for the callback function.
        max_sequence_length (`int`, defaults to 512):
            Maximum sequence length for the prompt.
        text_encoder_out_layers (`tuple[int]`):
            Layer indices for text encoder hidden state extraction.

    Examples:

    Returns:
        [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
    """

    # 1. Check inputs
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        prompt_embeds=prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # Reset the per-call state that is exposed through the read-only properties.
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # check_inputs guarantees prompt_embeds is set when prompt is None
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. prepare text embeddings
    prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        text_encoder_out_layers=text_encoder_out_layers,
    )

    # 4. process images
    if image is not None and not isinstance(image, list):
        image = [image]

    condition_images = None
    if image is not None:
        # Validate every image before preprocessing any of them.
        for img in image:
            self.image_processor.check_image_input(img)

        condition_images = []
        for img in image:
            image_width, image_height = img.size
            # Cap the reference image area at 1024 * 1024 pixels.
            if image_width * image_height > 1024 * 1024:
                img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
                image_width, image_height = img.size

            # Round both sides down to a multiple of vae_scale_factor * 2.
            multiple_of = self.vae_scale_factor * 2
            image_width = (image_width // multiple_of) * multiple_of
            image_height = (image_height // multiple_of) * multiple_of
            img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
            condition_images.append(img)
            # Only the first image can set the defaults: once set, `or` keeps the existing value.
            height = height or image_height
            width = width or image_width

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 5. prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_ids = self.prepare_latents(
        batch_size=batch_size * num_images_per_prompt,
        num_latents_channels=num_channels_latents,
        height=height,
        width=width,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
    )

    image_latents = None
    image_latent_ids = None
    if condition_images is not None:
        image_latents, image_latent_ids = self.prepare_image_latents(
            images=condition_images,
            batch_size=batch_size * num_images_per_prompt,
            generator=generator,
            device=device,
            dtype=self.vae.dtype,
        )

    # 6. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    # When the scheduler is configured with use_flow_sigmas, no sigmas are passed to it.
    if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
        sigmas = None
    image_seq_len = latents.shape[1]
    mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 7. Denoising loop with KV caching
    # Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
    # Steps 1+: forward_kv_cached (reuse cached ref K/V)
    # No ref images: standard forward
    self.scheduler.set_begin_index(0)
    kv_cache = None

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # An interrupt skips the body of every remaining step; the loop itself still runs to the end.
            if self.interrupt:
                continue

            self._current_timestep = t
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # NOTE(review): every branch passes `timestep / 1000` — presumably the transformer expects
            # timesteps scaled to [0, 1]; confirm against the transformer implementation.
            if i == 0 and image_latents is not None:
                # Step 0: include ref tokens, extract KV cache
                latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
                latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)

                output, kv_cache = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache_mode="extract",
                    num_ref_tokens=image_latents.shape[1],
                )
                noise_pred = output[0]

            elif kv_cache is not None:
                # Steps 1+: use cached ref KV, no ref tokens in input
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache=kv_cache,
                    kv_cache_mode="cached",
                )[0]

            else:
                # No reference images: standard forward
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # If the scheduler step changed the dtype, cast back — but only when an MPS backend is available.
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Requested names are looked up among this function's local variables.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace the latents and/or the prompt embeddings.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    # Clean up KV cache
    if kv_cache is not None:
        kv_cache.clear()

    self._current_timestep = None

    # Scatter the packed tokens back to a (B, C, H, W) grid using their position ids.
    latents = self._unpack_latents_with_ids(latents, latent_ids)

    # Undo the batch-norm normalization using the VAE's running statistics.
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
        latents.device, latents.dtype
    )
    latents = latents * latents_bn_std + latents_bn_mean
    latents = self._unpatchify_latents(latents)
    if output_type == "latent":
        # Return the de-normalized, unpatchified latents without decoding.
        image = latents
    else:
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return Flux2PipelineOutput(images=image)
reference.jpg ADDED

Git LFS Details

  • SHA256: 863bfb672582b4d6f921d81de98ec1a1b27fcef4d140bc633028d1efd770e55a
  • Pointer size: 131 Bytes
  • Size of remote file: 401 kB
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.13.0
2
+ diffusers==0.37.0
3
+ huggingface_hub==1.5.0
4
+ sentencepiece==0.2.1
5
+ protobuf==3.20.3
6
+ transformers==5.3.0
7
+ torch==2.9.1
8
+ Pillow
styles.json ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "anime",
4
+ "name": "Anime",
5
+ "emoji": "🎌",
6
+ "description": "Clean high-fidelity anime character portrait",
7
+ "prompt": "Transform the uploaded person into a polished anime character portrait with clean cel-shaded line art, large expressive eyes, refined face simplification, crisp eyelash detailing, smooth gradient skin tones, vibrant but controlled color design, elegant hair separation, hand-drawn animation appeal, premium key visual composition, modern theatrical anime finish, and the luminous character-design quality of a top-tier Japanese animation studio."
8
+ },
9
+ {
10
+ "id": "ps2_y2k_gaming",
11
+ "name": "PS2 / Y2K Gaming",
12
+ "emoji": "🎮",
13
+ "description": "Sixth-generation console render with glossy Y2K game aesthetics",
14
+ "prompt": "Transform the uploaded person into a PlayStation 2 era character render with early-2000s Y2K gaming aesthetics, low-to-mid polygon facial construction, crisp geometric forms, glossy skin shader, chunky specular highlights, baked ambient occlusion, subtle texture compression, dramatic menu-screen lighting, saturated console-era color grading, metallic futuristic fashion cues, stylish 2000s promotional energy, and the polished character-select screen look of a premium sixth-generation action game.",
15
+ "preview_seed": 101
16
+ },
17
+ {
18
+ "id": "claymation_stop_motion",
19
+ "name": "Claymation / Stop Motion",
20
+ "emoji": "🧱",
21
+ "description": "Handmade stop-motion puppet built from sculpted clay",
22
+ "prompt": "Transform the uploaded person into a handcrafted claymation stop-motion puppet made from sculpted plasticine, visible fingerprints in the clay, slightly uneven handmade facial symmetry, pressed clothing details, tactile matte surfaces, miniature set lighting, practical stop-motion puppet proportions, tiny prop realism, shallow miniature depth of field, charming handcrafted imperfections, and the cozy premium look of a photographed studio stop-motion character.",
23
+ "preview_seed": 102
24
+ },
25
+ {
26
+ "id": "pop_mart_blind_box",
27
+ "name": "Pop Mart / Blind Box Toy",
28
+ "emoji": "🧸",
29
+ "description": "Cute designer toy collectible with glossy vinyl finish",
30
+ "prompt": "Transform the uploaded person into a premium blind-box designer toy collectible, oversized head-to-body proportions, rounded limbs, simplified but recognizable facial features, glossy vinyl material, soft pastel luxury palette, tiny mouth, blush gradients, polished paint application, immaculate factory finish, cute shelf-display appeal, and the upscale commercial product-shot aesthetic of a high-end collectible art toy.",
31
+ "preview_seed": 103
32
+ },
33
+ {
34
+ "id": "gta_loading_screen",
35
+ "name": "GTA Loading Screen",
36
+ "emoji": "🚘",
37
+ "description": "Bold illustrated splash art inspired by classic GTA loading screens",
38
+ "prompt": "Transform the uploaded person into a bold Grand Theft Auto style loading-screen illustration, confident character attitude, sharp graphic outlines, semi-realistic comic-book painting, smooth digital airbrush shading, punchy warm highlights, high-contrast skin rendering, glamorous urban styling, poster-like color blocking, polished splash-art finish, and the unmistakable blockbuster open-world crime-game loading screen aesthetic.",
39
+ "preview_seed": 104
40
+ },
41
+ {
42
+ "id": "retro_90s_yearbook",
43
+ "name": "90s Yearbook / Retro",
44
+ "emoji": "📚",
45
+ "description": "Nostalgic 1990s school portrait with analog studio charm",
46
+ "prompt": "Transform the uploaded person into a 1990s yearbook portrait with mall-studio photography aesthetics, soft frontal flash, mottled blue or purple backdrop, gentle diffusion, natural analog skin tones, subtle film grain, slightly faded print quality, awkwardly sincere school-photo energy, retro grooming cues, thin halation, and the authentic printed look of a real 90s yearbook headshot.",
47
+ "preview_seed": 105
48
+ },
49
+ {
50
+ "id": "ghibli_90s_anime",
51
+ "name": "Ghibli / 90s Anime",
52
+ "emoji": "🌿",
53
+ "description": "Warm hand-painted 1990s fantasy anime portrait",
54
+ "prompt": "Transform the uploaded person into a warm 1990s fantasy anime portrait with gentle cel animation linework, soft hand-painted color styling, expressive but calm eyes, youthful face simplification, watercolor atmosphere, delicate blush tones, magical natural light, pastoral fantasy warmth, cinematic emotional storytelling, and the nostalgic painted-film quality of a classic 1990s anime feature.",
55
+ "preview_seed": 106
56
+ },
57
+ {
58
+ "id": "inflatable_balloon_art",
59
+ "name": "Inflatable / 3D Balloon Art",
60
+ "emoji": "🎈",
61
+ "description": "Glossy inflatable sculpture with oversized rounded forms",
62
+ "prompt": "Transform the uploaded person into a glossy inflatable balloon-art sculpture, oversized rounded anatomy, sealed vinyl seams, ultra-smooth reflective material, playful exaggerated volume, candy-like color design, high-specular reflections, inflated clothing shapes, whimsical premium 3D render quality, and the luxurious visual language of high-end contemporary balloon sculpture art.",
63
+ "preview_seed": 107
64
+ },
65
+ {
66
+ "id": "pixel_art_16bit",
67
+ "name": "16-Bit Pixel Art",
68
+ "emoji": "👾",
69
+ "description": "Detailed 16-bit console sprite portrait",
70
+ "prompt": "Transform the uploaded person into detailed 16-bit pixel art, carefully reduced facial features for sprite readability, limited retro palette, crisp pixel clusters, deliberate dithering, SNES-era RPG portrait design, bold silhouette readability, expressive eyes built from minimal pixel forms, clean sprite shading, cartridge-era charm, and the polished character-portrait look of a premium 1990s console role-playing game.",
71
+ "preview_seed": 108
72
+ },
73
+ {
74
+ "id": "clay_sculpture",
75
+ "name": "Clay",
76
+ "emoji": "🪵",
77
+ "description": "Handcrafted clay or ceramic bust with tactile sculpted detail",
78
+ "prompt": "Transform the uploaded person into a handcrafted clay sculpture portrait, tactile earthen material, sculpted cheekbones and hair ridges, ceramic or raw-clay texture variation, hand-tool marks, subtle cracks, carved garment detail, grounded natural palette, gallery-style side lighting, and the substantial physical presence of a carefully sculpted clay bust.",
79
+ "preview_seed": 109
80
+ },
81
+ {
82
+ "id": "toy_collectible",
83
+ "name": "Toy",
84
+ "emoji": "🧸",
85
+ "description": "Stylized collectible toy figure with premium retail finish",
86
+ "prompt": "Transform the uploaded person into a stylized collectible toy figure, premium molded plastic materials, clean manufactured edges, slightly enlarged head, simplified but recognizable likeness, painted toy detailing, glossy eyes, shelf-ready silhouette, commercial product lighting, and the finish quality of an officially released limited-edition character toy.",
87
+ "preview_seed": 110
88
+ },
89
+ {
90
+ "id": "sims_style",
91
+ "name": "The Sims Style",
92
+ "emoji": "💚",
93
+ "description": "Friendly life-simulation avatar with polished game-render styling",
94
+ "prompt": "Transform the uploaded person into a life-simulation game avatar with polished The Sims inspired styling, bright approachable face design, clean character-creator grooming, smooth game-render skin, slightly idealized proportions, cheerful suburban presentation, crisp digital lighting, customizable-avatar appeal, and the unmistakable promotional look of a premium life-sim character portrait.",
95
+ "preview_seed": 111
96
+ },
97
+ {
98
+ "id": "gta_character_art",
99
+ "name": "GTA Style",
100
+ "emoji": "🔫",
101
+ "description": "High-impact urban crime-game promotional character art",
102
+ "prompt": "Transform the uploaded person into stylized open-world crime-game key art with GTA-inspired visual language, sharp facial planes, smooth posterized shading, assertive street attitude, bold edge highlights, vivid urban contrast, semi-realistic comic painting, polished marketing illustration finish, and the high-impact character-poster energy of a blockbuster crime saga game.",
103
+ "preview_seed": 112
104
+ },
105
+ {
106
+ "id": "minecraft_style",
107
+ "name": "Minecraft Style",
108
+ "emoji": "🟩",
109
+ "description": "Block-built voxel avatar with iconic sandbox-game geometry",
110
+ "prompt": "Transform the uploaded person into a Minecraft-inspired voxel character, square head, cubic limbs, block-built anatomy, pixel-textured clothing, low-resolution texture mapping, readable blocky hairstyle translation, simplified face, bright sandbox-world lighting, chunky geometric construction, and the instantly recognizable look of a polished block-game avatar.",
111
+ "preview_seed": 113
112
+ },
113
+ {
114
+ "id": "roblox_style",
115
+ "name": "Roblox Style",
116
+ "emoji": "🟥",
117
+ "description": "Playful modular avatar with clean platform-game styling",
118
+ "prompt": "Transform the uploaded person into a Roblox-inspired avatar portrait, modular toy-like body construction, simplified geometry, glossy materials, playful proportions, readable face design, bright platform-game lighting, fashionable avatar styling, clean digital render finish, and the vibrant social-gaming look of a premium customizable online avatar.",
119
+ "preview_seed": 114
120
+ },
121
+ {
122
+ "id": "vhs_style",
123
+ "name": "VHS Style",
124
+ "emoji": "📼",
125
+ "description": "Degraded analog tape still with authentic VHS artifacts",
126
+ "prompt": "Transform the uploaded person into a VHS-era analog video still, interlaced scanlines, tape noise, chromatic bleed, tracking distortion, magnetic dropouts, dated camcorder color response, low-resolution softness, blown highlights, shadow crush, slight timestamp-era nostalgia, and the authentic degraded feeling of a paused home-video frame transferred from tape.",
127
+ "preview_seed": 115
128
+ },
129
+ {
130
+ "id": "polaroid_style",
131
+ "name": "Polaroid Style",
132
+ "emoji": "🖼️",
133
+ "description": "Instant-film portrait with soft flash and chemical color shifts",
134
+ "prompt": "Transform the uploaded person into a classic Polaroid instant-photo portrait, soft flash falloff, creamy analog color shifts, lightly washed highlights, vintage chemical film response, subtle edge fading, gentle grain, tactile print nostalgia, candid warmth, imperfect instant-camera sharpness, and the authentic charm of a real instant photograph.",
135
+ "preview_seed": 116
136
+ },
137
+ {
138
+ "id": "webtoon_style",
139
+ "name": "Webtoon Style",
140
+ "emoji": "📱",
141
+ "description": "Glossy serialized-comic portrait with polished webtoon rendering",
142
+ "prompt": "Transform the uploaded person into a polished webtoon character portrait, elegant clean line art, luminous skin rendering, expressive stylized eyes, soft gradient shading, romantic dramatic atmosphere, fashion-forward character styling, glossy digital coloring, refined facial simplification, and the premium visual quality of a top-tier Korean webtoon cover panel.",
143
+ "preview_seed": 117
144
+ },
145
+ {
146
+ "id": "pixar_style",
147
+ "name": "Pixar Style",
148
+ "emoji": "✨",
149
+ "description": "Cinematic stylized 3D animated character render",
150
+ "prompt": "Transform the uploaded person into a stylized cinematic 3D animated character with Pixar-inspired appeal, expressive eyes, carefully designed facial proportions, believable subsurface scattering, smooth skin shading, clean hair groom, emotionally readable expression, art-directed color harmony, cinematic key-and-fill lighting, and the premium theatrical quality of a feature-animation hero render.",
151
+ "preview_seed": 118
152
+ },
153
+ {
154
+ "id": "wizard_school_fantasy",
155
+ "name": "Wizard School Identity",
156
+ "emoji": "🪄",
157
+ "description": "Magical academy identity portrait in a spellbound school world",
158
+ "prompt": "Transform the uploaded person into a magical boarding-school fantasy character portrait, scholarly robe tailoring, enchanted academic atmosphere, candlelit old-world interiors, brass and parchment details, subtle magical particles, house-color accents, moody castle lighting, elegant fantasy realism, and the immersive identity-shift feeling of belonging inside a grand wizard school universe.",
159
+ "preview_seed": 119
160
+ },
161
+ {
162
+ "id": "dune_desert_epic",
163
+ "name": "Dune Identity",
164
+ "emoji": "🏜️",
165
+ "description": "Mythic desert sci-fi identity with austere cinematic grandeur",
166
+ "prompt": "Transform the uploaded person into a desert-epic science-fiction character, windswept arid atmosphere, monumental cinematic seriousness, sand-toned palette, survivalist desert garments, futuristic nomadic detailing, austere regal styling, harsh sunlit contrast, drifting dust, mythic scale, and the solemn high-fashion science-fantasy presence of a legendary desert-world figure.",
167
+ "preview_seed": 120
168
+ },
169
+ {
170
+ "id": "viking_saga",
171
+ "name": "Viking Identity",
172
+ "emoji": "🪓",
173
+ "description": "Rugged Norse saga portrait with warrior realism",
174
+ "prompt": "Transform the uploaded person into a Viking saga character portrait, weathered Nordic styling, braided hair or fur-lined garments, carved metal detail, worn leather and wool textures, cold-climate realism, stormy dramatic light, warrior dignity, historical-fantasy grit, and the fierce mythic aura of a cinematic Norse hero.",
175
+ "preview_seed": 121
176
+ },
177
+ {
178
+ "id": "superhero_identity",
179
+ "name": "Super Hero Identity",
180
+ "emoji": "🦸",
181
+ "description": "Blockbuster hero reveal with stylized cinematic power",
182
+ "prompt": "Transform the uploaded person into a hyper-stylized superhero character portrait, iconic costume-design cues, bold chest-and-shoulder silhouette, cinematic heroic lighting, dramatic backlight, high-energy color blocking, polished comic-to-film realism, powerful confident expression, and the larger-than-life franchise-poster aura of a modern blockbuster hero reveal.",
183
+ "preview_seed": 122
184
+ },
185
+ {
186
+ "id": "paper_cutout",
187
+ "name": "Paper Cutout",
188
+ "emoji": "✂️",
189
+ "description": "Layered paper collage with tactile handcrafted depth",
190
+ "prompt": "Transform the uploaded person into a layered paper-cut illustration, hand-cut cardstock shapes, stacked paper planes, crisp scissor-cut edges, soft cast shadows between layers, simplified graphic facial features, subtle paper fiber texture, handcrafted composition, and the elegant dimensional charm of premium artisanal papercraft artwork.",
191
+ "preview_seed": 123
192
+ },
193
+ {
194
+ "id": "baroque",
195
+ "name": "Baroque",
196
+ "emoji": "👑",
197
+ "description": "Ornate old-master portrait with dramatic Baroque lighting",
198
+ "prompt": "Transform the uploaded person into a grand Baroque portrait painting, rich velvet and silk costume textures, ornate gold embellishment, museum-grade oil-paint surface, dramatic chiaroscuro, deep tenebrist shadows, warm candlelit highlights, aristocratic posture, old-master brushwork, and the opulent 17th-century grandeur of a prestigious European court portrait.",
199
+ "preview_seed": 124
200
+ },
201
+ {
202
+ "id": "dark_fantasy",
203
+ "name": "Dark Fantasy",
204
+ "emoji": "🕯️",
205
+ "description": "Brooding gothic fantasy portrait with ominous atmosphere",
206
+ "prompt": "Transform the uploaded person into a dark fantasy character portrait, gothic worldbuilding atmosphere, ominous candlelit shadows, ancient armor or ritual garments, ash-and-blood color accents, weathered surfaces, haunted elegance, moody volumetric haze, cursed-kingdom tone, and the brooding visual depth of a prestige dark-fantasy epic.",
207
+ "preview_seed": 125
208
+ },
209
+ {
210
+ "id": "cyberpunk_neon_noir",
211
+ "name": "Cyberpunk / Neon Noir",
212
+ "emoji": "🌆",
213
+ "description": "Rain-soaked futuristic noir lit by neon signage",
214
+ "prompt": "Transform the uploaded person into a neon-noir cyberpunk portrait, rain-slick reflections, magenta and cyan practical lighting, holographic signage glow, moody urban darkness, reflective synthetic fabrics, futuristic street-culture styling, soft atmospheric mist, glossy skin highlights from neon, and the immersive nighttime mood of a premium neon-soaked sci-fi thriller.",
215
+ "preview_seed": 126
216
+ },
217
+ {
218
+ "id": "ukiyoe_style",
219
+ "name": "Ukiyo-e",
220
+ "emoji": "🌊",
221
+ "description": "Elegant Japanese floating-world portrait print",
222
+ "prompt": "Transform the uploaded person into a refined ukiyo-e portrait, graceful flattened perspective, controlled ink outlines, flowing garment silhouettes, muted mineral pigments, washi-paper feel, decorative hair treatment, balanced negative space, Edo-period print sensibility, and the serene crafted beauty of a classic Japanese floating-world print.",
223
+ "preview_seed": 127
224
+ },
225
+ {
226
+ "id": "woodblock_print",
227
+ "name": "Woodblock",
228
+ "emoji": "🪵",
229
+ "description": "Hand-pulled relief print with carved texture and ink variation",
230
+ "prompt": "Transform the uploaded person into a traditional woodblock print portrait, carved-line texture, visible ink transfer irregularities, reduced tonal shapes, tactile handmade printmaking character, rough-paper grain, artisan block-carved detail, limited print palette, and the physical analog feeling of a carefully hand-pulled relief print.",
231
+ "preview_seed": 128
232
+ },
233
+ {
234
+ "id": "action_figure",
235
+ "name": "Action Figure",
236
+ "emoji": "🦾",
237
+ "description": "Collector-grade articulated figure based on the subject",
238
+ "prompt": "Transform the uploaded person into a premium articulated action figure, realistic molded head sculpt, visible joint engineering, semi-matte plastic body, painted costume detail, collector-grade product lighting, hero-franchise presentation, crisp accessory-ready design language, and the convincing look of a high-end retail action figure based on the subject.",
239
+ "preview_seed": 129
240
+ },
241
+ {
242
+ "id": "doll_packaging",
243
+ "name": "Doll Packaging",
244
+ "emoji": "🎁",
245
+ "description": "Fashion doll presented inside premium retail box packaging",
246
+ "prompt": "Transform the uploaded person into a fashion doll presentation inside premium retail packaging, stylized doll proportions, glossy plastic face paint, immaculate rooted-hair appearance, product-window box framing, coordinated accessories, bright commercial toy photography lighting, shelf-display glamour, and the fully merchandised elegance of a high-end doll-box release.",
247
+ "preview_seed": 130
248
+ },
249
+ {
250
+ "id": "crochet_style",
251
+ "name": "Crochet",
252
+ "emoji": "🧶",
253
+ "description": "Hand-crocheted doll portrait with visible looped stitching",
254
+ "prompt": "Transform the uploaded person into a handcrafted crochet figure portrait, visible looped stitches, yarn-built facial features, soft stuffed volume, carefully stitched clothing detail, tactile fiber texture, warm craft-table atmosphere, subtle handmade irregularity, and the lovable artisanal quality of an expertly crocheted doll.",
255
+ "preview_seed": 131
256
+ },
257
+ {
258
+ "id": "yarn_plush",
259
+ "name": "Yarn",
260
+ "emoji": "🪡",
261
+ "description": "Soft plush character built from thick colorful yarn",
262
+ "prompt": "Transform the uploaded person into a plush yarn-crafted character, fluffy wool fibers, soft textile construction, rounded stuffed-toy anatomy, fuzzy tactile surface, adorable stitched detailing, cozy handmade softness, gentle nursery-style lighting, and the charming volume of a premium yarn plush character.",
263
+ "preview_seed": 132
264
+ },
265
+ {
266
+ "id": "creepy_doll",
267
+ "name": "Creepy Doll",
268
+ "emoji": "🪆",
269
+ "description": "Uncanny haunted-doll portrait with aged porcelain detail",
270
+ "prompt": "Transform the uploaded person into an eerie creepy-doll portrait, glassy unsettling eyes, cracked porcelain texture, antique doll construction, faintly uncanny smile, distressed costume details, dim haunted-house lighting, vintage-toy horror atmosphere, subtle age and wear marks, and the disturbing but carefully crafted presence of a haunted collectible doll.",
271
+ "preview_seed": 133
272
+ },
273
+ {
274
+ "id": "ultra_ugly_caricature",
275
+ "name": "Ultra Ugly Caricature",
276
+ "emoji": "🤪",
277
+ "description": "Extreme grotesque caricature pushed for maximum comedic distortion",
278
+ "prompt": "Transform the uploaded person into an ultra-exaggerated ugly caricature, intentionally awkward proportions, amplified facial asymmetry, oversized nose and jaw cues when naturally present, absurd comedic distortion, lumpy stylization, grotesque editorial-cartoon boldness, ugly-but-recognizable likeness, and the maximum-pushed comedic exaggeration of a ruthless caricature artist rendering.",
279
+ "preview_seed": 134
280
+ },
281
+ {
282
+ "id": "classic_caricature",
283
+ "name": "Caricature",
284
+ "emoji": "🎭",
285
+ "description": "Playful editorial caricature that exaggerates signature features",
286
+ "prompt": "Transform the uploaded person into a polished caricature portrait, amplified signature facial traits, playful proportions, expressive linework, satirical illustration charm, painterly editorial finish, witty character design, lively exaggerated smile and eye shapes, and the appealing handcrafted look of a professional caricature artist creating a recognizable stylized likeness.",
287
+ "preview_seed": 135
288
+ }
289
+ ]