# Author: artificialguybr
# Last commit: 13926bd "Update app.py" (verified)
import os
import json
import random
import subprocess
import tempfile
import spaces
# Optional Hugging Face access token read from the environment (None when
# unset); passed as `token=` when the model is loaded.
HF_TOKEN = os.environ.get("HF_TOKEN")
def apply_patch():
    """Apply the bundled ``flux2_klein_kv.patch`` to the installed diffusers package.

    Best effort: if the patch file is absent nothing happens, and a non-zero
    exit from ``patch`` is reported on stdout instead of aborting start-up
    (``--forward`` makes an already-applied patch a skip rather than a revert).
    The patch is applied with ``-p2`` from the directory that contains the
    ``diffusers`` package.
    """
    import diffusers

    diffusers_dir = os.path.dirname(diffusers.__file__)
    install_root = os.path.dirname(diffusers_dir)
    # Resolve next to this file (like APP_DIR) rather than the process cwd.
    patch_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "flux2_klein_kv.patch"
    )
    if not os.path.exists(patch_file):
        return
    # `with` guarantees the stdin handle is closed once `patch` has consumed it.
    with open(patch_file, "rb") as patch_stream:
        result = subprocess.run(
            ["patch", "-p2", "--forward", "--batch"],
            cwd=install_root,
            stdin=patch_stream,
            capture_output=True,
            text=True,
        )
    if result.returncode != 0:
        print(
            f"apply_patch: `patch` exited with code {result.returncode}:\n"
            f"{result.stdout}{result.stderr}"
        )


# Must run before the patched diffusers module is imported below.
apply_patch()
import numpy as np
import torch
from PIL import Image
from diffusers.pipelines.flux2.pipeline_flux2_klein_kv import Flux2KleinKVPipeline
import gradio as gr
# --- Paths, model id and runtime configuration ---------------------------------
APP_DIR = os.path.dirname(os.path.abspath(__file__))
STYLES_PATH = os.path.join(APP_DIR, "styles.json")
# Optional reference photo; when present it is used to render the per-style
# gallery previews.
REFERENCE_PATH = os.path.join(APP_DIR, "reference.jpg")
MODEL_ID = "black-forest-labs/FLUX.2-klein-9b-kv"
# Prepended to every style prompt (via build_prompt) to ask the model to keep
# identity, pose and scene layout while changing only the visual style.
MANDATORY_PROMPT = (
    "You must preserve the exact identity, facial features, body proportions, and pose of the subject. "
    "Keep the original camera framing, object positions, and full scene layout unchanged. "
    "Apply only the requested visual style while maintaining total consistency and realism."
)
dtype = torch.bfloat16
# Device used for torch.Generator objects. Note that the pipeline itself is
# moved to "cuda" unconditionally further down.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Largest signed 32-bit integer; upper bound for user-chosen and random seeds.
MAX_SEED = np.iinfo(np.int32).max
# Sequence of style dicts. Keys read in this file: "id", "name", "prompt",
# "emoji" and, optionally, "preview_seed".
with open(STYLES_PATH, "r", encoding="utf-8") as f:
    STYLES = json.load(f)
# The pipeline is loaded once at import time and shared by all requests.
pipe = Flux2KleinKVPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype, token=HF_TOKEN)
pipe.to("cuda")
# style id -> preview PIL image (or None); filled by generate_previews().
PREVIEW_CACHE = {}
def prepare_image(image, max_size=1024, multiple=16, min_size=256):
    """Resize ``image`` so its long side is ``max_size`` and it keeps its aspect ratio.

    The short side is scaled by the aspect ratio and rounded to the nearest
    multiple of ``multiple``. Both sides are then clamped to
    ``[min_size, long side]``, with ``min_size`` rounded up to ``multiple``
    and capped at the long side. Extremely wide or tall inputs therefore
    lose some aspect ratio once the short side hits that floor. With the
    defaults this is identical to the original 1024 / 16 / 256 behaviour.

    Args:
        image: PIL image (anything exposing ``.size`` and ``.resize``).
        max_size: Upper bound for the long side. It is snapped down to a
            multiple of ``multiple`` so that the long side stays on the same
            grid as the short side.
        multiple: Grid that both output dimensions are aligned to.
        min_size: Lower bound for either side.

    Returns:
        ``(resized_image, width, height)``.

    Raises:
        ValueError: if the image has zero width or height.
    """
    iw, ih = image.size
    if iw <= 0 or ih <= 0:
        raise ValueError(f"Image has invalid size {iw}x{ih}")
    long_side = max(multiple, max_size // multiple * multiple)
    # Ceil min_size onto the grid, but never let the floor exceed the cap.
    min_side = min(-(-min_size // multiple) * multiple, long_side)
    ar = iw / ih
    if ar >= 1:
        w = long_side
        h = round(long_side / ar / multiple) * multiple
    else:
        h = long_side
        w = round(long_side * ar / multiple) * multiple
    w = max(min_side, min(long_side, w))
    h = max(min_side, min(long_side, h))
    return image.resize((w, h), Image.LANCZOS), w, h
def build_prompt(style_prompt):
    """Return MANDATORY_PROMPT followed by one space and ``style_prompt``."""
    return "{} {}".format(MANDATORY_PROMPT, style_prompt)
def save_generated_image(image):
    """Save ``image`` as a PNG in the system temp directory and return its path.

    On success the file is intentionally left on disk so that it can be handed
    to the download button. If saving fails, the temp file is removed and the
    original exception is re-raised.
    """
    fd, path = tempfile.mkstemp(prefix="style-my-portrait-", suffix=".png")
    # Only the path is needed; the image writer opens the file itself.
    os.close(fd)
    try:
        image.save(path, format="PNG")
    except BaseException:
        # Don't leave an empty or partial temp file behind.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path
@spaces.GPU(duration=30)
def transform_image(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
    """Restyle an uploaded photo with the selected style.

    Args:
        image: Uploaded PIL image, or None if nothing has been uploaded.
        style_id: ``"id"`` of an entry in STYLES.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, a fresh seed is drawn from ``[0, MAX_SEED]``.
        num_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker.

    Returns:
        ``((original_resized, result), seed, original_resized, result)``: the
        before/after pair for the image slider, the seed actually used (int),
        and both images individually.

    Raises:
        gr.Error: if no image was uploaded or ``style_id`` matches no style.
    """
    if image is None:
        raise gr.Error("Upload a photo first!")
    style = next((s for s in STYLES if s["id"] == style_id), None)
    if style is None:
        raise gr.Error("Select a valid style!")
    # Slider values may arrive as floats. torch.Generator.manual_seed and the
    # step count need ints, so coerce both here.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    num_steps = int(num_steps)
    original_resized, w, h = prepare_image(image)
    generator = torch.Generator(device=device).manual_seed(seed)
    progress(0.2, desc=f"Applying {style['name']} style...")
    result = pipe(
        prompt=build_prompt(style["prompt"]),
        image=[original_resized],
        height=h,
        width=w,
        num_inference_steps=num_steps,
        generator=generator,
    ).images[0]
    progress(0.9, desc="Preparing before/after slider...")
    return (original_resized, result), seed, original_resized, result
@spaces.GPU(duration=120)
def generate_previews(progress=gr.Progress(track_tqdm=True)):
    """Render one preview per style from REFERENCE_PATH and cache the result.

    Returns a dict mapping each style id to a PIL image. The value is None for
    styles whose generation raised, or for every style when the reference
    image does not exist. The dict is stored in PREVIEW_CACHE; once that is
    non-empty, later calls return it without re-rendering.
    """
    global PREVIEW_CACHE
    if PREVIEW_CACHE:
        return PREVIEW_CACHE
    if not os.path.exists(REFERENCE_PATH):
        PREVIEW_CACHE = {s["id"]: None for s in STYLES}
        return PREVIEW_CACHE
    # The context manager releases the file handle Image.open keeps once the
    # pixels have been converted.
    with Image.open(REFERENCE_PATH) as ref_file:
        ref_img = ref_file.convert("RGB")
    ref_resized, w, h = prepare_image(ref_img)
    previews = {}
    total = len(STYLES)
    for i, style in enumerate(STYLES):
        progress(i / total, desc=f"Generating preview: {style['name']}")
        generator = torch.Generator(device=device).manual_seed(style.get("preview_seed", 42))
        try:
            previews[style["id"]] = pipe(
                prompt=build_prompt(style["prompt"]),
                image=[ref_resized],
                height=h,
                width=w,
                num_inference_steps=4,
                generator=generator,
            ).images[0]
        except Exception as e:
            # Best effort: one failing style gets a placeholder instead of
            # aborting the whole gallery.
            print(f"Error generating preview for {style['name']}: {e}")
            previews[style["id"]] = None
    # Publish only a complete set, so that an interrupted run is not cached as final.
    PREVIEW_CACHE = previews
    return PREVIEW_CACHE
def load_previews_to_gallery():
    """Build ``(image, caption)`` gallery items, one per style, in STYLES order.

    The caption is ``"<emoji> <name>"``. Styles without a preview image get a
    dark 512x512 placeholder. A single placeholder is created lazily and
    reused for all such styles.
    """
    previews = generate_previews()
    placeholder = None
    gallery_items = []
    for style in STYLES:
        img = previews.get(style["id"])
        if img is None:
            if placeholder is None:
                placeholder = Image.new("RGB", (512, 512), (20, 18, 30))
            img = placeholder
        gallery_items.append((img, f"{style['emoji']} {style['name']}"))
    return gallery_items
def select_style_from_gallery(evt: gr.SelectData):
    """Map a gallery click to the id of the clicked style.

    Gallery items are built in STYLES order, so the event index indexes
    STYLES directly.
    """
    return STYLES[evt.index]["id"]
# Custom stylesheet passed to gr.Blocks(css=...). The selectors target the
# elem_id / elem_classes values assigned to the components in the layout
# (#style-gallery, #input-img, #output-slider, #go-btn, #dl-btn, .main-title,
# .subtitle, .step-guide, .credit-line). The rule on the last line hides the
# page footer.
css = """
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
.gradio-container { background: radial-gradient(circle at 0% 0%, #0f1f2f 0%, #070f1f 60%, #060c18 100%) !important; max-width: 1120px !important; margin: 0 auto !important; font-family: 'Space Grotesk', sans-serif !important; }
.main-title h1 { text-align: center; color: #22d3ee !important; font-size: 2.6em !important; font-weight: 700 !important; text-shadow: 0 0 20px #22d3ee33; letter-spacing: -1px; margin-bottom: 0 !important; }
.subtitle p { text-align: center; color: #94a3b8 !important; font-size: 1.05em !important; margin-top: 0 !important; }
.step-guide p { text-align: center; color: #cbd5e1 !important; font-size: 1.03em !important; margin: 6px 0 14px !important; }
.credit-line p { text-align: center; color: #cbd5e1 !important; font-size: 0.98em !important; margin: 4px 0 20px !important; }
.credit-line a { color: #67e8f9 !important; text-decoration: none !important; font-weight: 700 !important; }
.credit-line a:hover { text-decoration: underline !important; }
#style-gallery { border: 1px solid #1e3a5f !important; border-radius: 12px !important; background: #0f1a30 !important; }
#style-gallery .grid-wrap { gap: 8px !important; }
#style-gallery .thumbnail-item { border: 2px solid transparent !important; border-radius: 10px !important; transition: all 0.2s ease !important; }
#style-gallery .thumbnail-item:hover { border-color: #22d3ee66 !important; transform: translateY(-2px); }
#style-gallery .thumbnail-item.selected,
#style-gallery .thumbnail-item[aria-selected="true"] {
border-color: #22d3ee !important;
box-shadow: 0 0 0 2px #22d3ee66, 0 10px 26px #22d3ee4d !important;
transform: translateY(-2px) scale(1.02) !important;
}
#input-img { border: 2px dashed #1e3a5f !important; border-radius: 12px !important; background: #0f1a30 !important; min-height: 320px; }
#output-slider { border: 1px solid #1e3a5f !important; border-radius: 12px !important; background: #0f1a30 !important; overflow: hidden !important; }
#go-btn { background: linear-gradient(135deg, #0ea5e9, #22d3ee) !important; color: #04131f !important; font-weight: 700 !important; font-size: 1.15em !important; min-height: 52px !important; border: none !important; border-radius: 12px !important; box-shadow: 0 4px 20px #22d3ee33; transition: all 0.2s ease !important; }
#go-btn:hover { box-shadow: 0 6px 26px #22d3ee4d; transform: translateY(-2px); }
#dl-btn { background: #0b1629 !important; color: #67e8f9 !important; border: 1px solid #1e3a5f !important; border-radius: 10px !important; }
.progress-bar { background-color: #22d3ee !important; }
.progress-bar-wrap { background-color: #0b1629 !important; }
* { --neutral-50: #0f1a30 !important; --neutral-100: #10203b !important; --neutral-200: #1e3a5f !important; }
.dark { --body-background-fill: #060c18; }
footer { display: none !important; }
"""
# --- Gradio UI -------------------------------------------------------------------
# Layout:
#   - header text;
#   - a two-column row, with upload, settings and the Transform button on the
#     left and the style gallery on the right;
#   - the before/after result slider;
#   - a download button that stays hidden until a result exists.
# The theme forces the same dark palette in both light and dark modes.
# NOTE(review): block nesting was reconstructed from an unindented copy; confirm
# that the Result heading and rows sit at Blocks level as shown here.
with gr.Blocks(title="Style My Portrait", css=css, theme=gr.themes.Base(
    primary_hue=gr.themes.colors.cyan,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Space Grotesk"),
).set(
    body_background_fill="#060c18",
    body_background_fill_dark="#060c18",
    block_background_fill="#0f1a30",
    block_background_fill_dark="#0f1a30",
    block_border_color="#1e3a5f",
    block_border_color_dark="#1e3a5f",
    block_label_text_color="#67e8f9",
    block_label_text_color_dark="#67e8f9",
    block_title_text_color="#67e8f9",
    block_title_text_color_dark="#67e8f9",
    body_text_color="#e2e8f0",
    body_text_color_dark="#e2e8f0",
    button_primary_background_fill="#22d3ee",
    button_primary_background_fill_dark="#22d3ee",
    button_primary_text_color="#04131f",
    button_primary_text_color_dark="#04131f",
    input_background_fill="#10203b",
    input_background_fill_dark="#10203b",
    input_border_color="#1e3a5f",
    input_border_color_dark="#1e3a5f",
    border_color_accent="#22d3ee",
    border_color_accent_dark="#22d3ee",
    border_color_primary="#1e3a5f",
    border_color_primary_dark="#1e3a5f",
    background_fill_secondary="#0f1a30",
    background_fill_secondary_dark="#0f1a30",
    background_fill_primary="#060c18",
    background_fill_primary_dark="#060c18",
    shadow_drop="none",
    shadow_drop_lg="none",
    slider_color="#22d3ee",
    slider_color_dark="#22d3ee",
    checkbox_background_color="#10203b",
    checkbox_background_color_dark="#10203b",
    checkbox_background_color_selected="#0ea5e9",
    checkbox_background_color_selected_dark="#0ea5e9",
)) as demo:
    gr.Markdown("# Style My Portrait", elem_classes="main-title")
    gr.Markdown("Turn your portrait into polished visual styles with FLUX.2 Klein", elem_classes="subtitle")
    gr.Markdown("**1)** Upload your photo • **2)** Pick a style from the gallery • **3)** Click transform", elem_classes="step-guide")
    gr.Markdown(
        "Built by [@artificialguybr](https://twitter.com/artificialguybr) • Explore more image editing and image generation prompts at [artificialguy.com](https://artificialguy.com) and [findgoodprompt.com](https://findgoodprompt.com)",
        elem_classes="credit-line",
    )
    # Per-session state: the selected style id (initially the first style) and
    # the last original and result images.
    selected_style_id = gr.State(STYLES[0]["id"])
    original_state = gr.State(None)
    enhanced_state = gr.State(None)
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(label="📸 Upload & Preview", type="pil", elem_id="input-img")
            with gr.Accordion("⚙️ Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=20, step=1, value=4)
            go_btn = gr.Button("🚀 Transform", elem_id="go-btn", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("### 🎨 Gallery", elem_classes="subtitle")
            # The gallery value is computed while the UI is being built. This
            # calls generate_previews() (GPU) at startup unless PREVIEW_CACHE
            # is already populated.
            style_gallery = gr.Gallery(
                label="Style Gallery",
                show_label=False,
                elem_id="style-gallery",
                columns=[4], rows=[2],
                height=360,
                object_fit="cover",
                allow_preview=False,
                value=load_previews_to_gallery()
            )
            # Clicking a tile stores that style's id in selected_style_id. The
            # event data is injected through the gr.SelectData annotation.
            style_gallery.select(
                fn=select_style_from_gallery,
                inputs=[],
                outputs=[selected_style_id]
            )
    gr.Markdown("### 🖼️ Result", elem_classes="subtitle")
    with gr.Row():
        output_image = gr.ImageSlider(label="Before / After", type="pil", elem_id="output-slider", slider_position=50)
    with gr.Row():
        dl_btn = gr.DownloadButton("📥 Download Image", elem_id="dl-btn", visible=False)

    def on_generate(image, style_id, seed, randomize_seed, num_steps, progress=gr.Progress(track_tqdm=True)):
        """Run transform_image, save the result to a temp PNG and reveal the download button.

        Returns values for [output_image, seed, original_state, enhanced_state,
        dl_btn], in that order.
        """
        comparison, seed, orig, enh = transform_image(image, style_id, seed, randomize_seed, num_steps, progress)
        download_path = save_generated_image(enh)
        return comparison, seed, orig, enh, gr.update(visible=True, value=download_path)

    # The seed slider is also an output, so it shows the seed actually used.
    go_btn.click(
        fn=on_generate,
        inputs=[input_image, selected_style_id, seed, randomize_seed, num_inference_steps],
        outputs=[output_image, seed, original_state, enhanced_state, dl_btn],
    )
    # Any change to the upload invalidates the previous result: clear the
    # slider and the state, and hide the download button.
    input_image.change(
        fn=lambda: (None, gr.update(visible=False), None, None),
        inputs=[],
        outputs=[output_image, dl_btn, original_state, enhanced_state],
    )
demo.launch(ssr_mode=False)