| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionPipeline, DDIMScheduler
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from PIL import Image
import time
import warnings

# Silence noisy third-party deprecation warnings in the demo UI.
warnings.filterwarnings("ignore")

# CPU-only demo: no CUDA is assumed anywhere in this app.
DEVICE = "cpu"

# NOTE(review): mkldnn is disabled, presumably to avoid CPU-backend
# numerical/compatibility issues with this pipeline — confirm still needed.
torch.backends.mkldnn.enabled = False

# Hugging Face model id for Stable Diffusion v1.4.
MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Module-level cache so the heavy pipeline is loaded only once per process.
PIPE_CACHE = None
|
def get_pipe():
    """
    Load and cache the Stable Diffusion v1-4 pipeline on CPU.

    The safety checker is disabled explicitly (``safety_checker=None``
    together with ``requires_safety_checker=False``) so diffusers neither
    warns nor tries to reload it. The scheduler is swapped for DDIM so the
    intermediate latents form a deterministic denoising trajectory.

    Returns:
        StableDiffusionPipeline: the cached, CPU-resident pipeline.
    """
    global PIPE_CACHE
    if PIPE_CACHE is not None:
        return PIPE_CACHE

    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # DDIM gives a clean, deterministic step-by-step trajectory to visualize.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    pipe.to(DEVICE)

    PIPE_CACHE = pipe
    return PIPE_CACHE
| |
|
| |
|
| | |
| |
|
def compute_pca(latents):
    """
    Project a list of latents onto their first two principal components.

    Args:
        latents: list of (C, H, W) numpy arrays, one per diffusion step.

    Returns:
        (N, 2) numpy array of PCA coordinates, one row per step. An
        all-zero array is returned when PCA cannot be computed (fewer
        than 2 samples, or a numerical failure inside sklearn).
    """
    if not latents:
        return np.zeros((0, 2))
    X = np.stack([x.flatten() for x in latents])
    # PCA needs at least 2 samples to fit 2 components.
    if X.shape[0] < 2:
        return np.zeros((X.shape[0], 2))
    try:
        return PCA(n_components=2).fit_transform(X)
    except Exception:
        # Best-effort: the plot is optional, never crash the app over it.
        return np.zeros((X.shape[0], 2))
| |
|
| |
|
def compute_norm(latents):
    """
    Compute the L2 norm of each latent over all of its dimensions.

    Args:
        latents: list of (C, H, W) numpy arrays.

    Returns:
        list[float]: one norm per latent; empty list for empty input.
    """
    if not latents:
        return []
    # norm() with axis=None already operates on the raveled array; the
    # explicit flatten keeps the intent obvious.
    return [float(np.linalg.norm(x.flatten())) for x in latents]
| |
|
| |
|
| | |
| |
|
def decode_latent(pipe, latent_np):
    """
    Decode a single latent into an RGB PIL image using the pipeline's VAE.

    Args:
        pipe: loaded StableDiffusionPipeline (provides ``.vae``).
        latent_np: (C, H, W) numpy array for one diffusion step.

    Returns:
        PIL.Image.Image: decoded RGB image (256x256 for 32x32 latents).
    """
    latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    # Latents are stored scaled by the VAE scaling factor; undo it first.
    scale = pipe.vae.config.scaling_factor
    with torch.no_grad():
        image = pipe.vae.decode(latent / scale).sample
    # VAE output is in [-1, 1]; map to [0, 1], then to uint8 pixels.
    image = (image / 2 + 0.5).clamp(0, 1)
    np_img = (image[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(np_img)
| |
|
| |
|
| | |
| |
|
def run_diffusion(prompt, steps, guidance, seed, simple):
    """
    Run SD v1-4 at 256x256 on CPU, capturing latents at EVERY step.

    Args:
        prompt: text prompt; blank input short-circuits with a warning.
        steps: number of DDIM inference steps.
        guidance: classifier-free guidance scale.
        seed: RNG seed; None or negative picks a time-based seed.
        simple: if True, return the kid-friendly explanation text.

    Returns:
        7-tuple matching the Gradio outputs:
        (final image, explanation markdown, step-slider update,
         image at current step, PCA figure, norm figure, state dict).
    """
    if not prompt or not prompt.strip():
        return (
            None,
            "⚠️ Please enter a prompt.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {},
        )

    pipe = get_pipe()

    steps = int(steps)
    guidance = float(guidance)

    # Negative / missing seed means "random": derive one from wall-clock time.
    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []
    # NOTE(review): timesteps are recorded but never surfaced in the UI or
    # state — kept for parity / future use.
    timesteps = []

    def callback(step: int, timestep: int, latents: torch.FloatTensor):
        # Record a CPU copy of the (single) latent at every denoising step.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps.append(int(timestep))

    t0 = time.time()
    try:
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        return (
            None,
            f"❌ Diffusion error: {e}",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": str(e)},
        )

    total = time.time() - t0

    if not latents_list:
        return (
            None,
            "❌ No latents collected. Something went wrong inside the pipeline.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": "no_latents"},
        )

    final_image = result.images[0]

    # Trajectory analysis over every captured step.
    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)

    # Default the step viewer to the last (cleanest) step.
    current_idx = len(latents_list) - 1

    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        step_image = None

    if simple:
        explanation = (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    else:
        explanation = (
            "🔬 **Technical explanation:**\n\n"
            "- We run a DDIM diffusion process over the latent space.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
            "- We record `zₜ` at every step and decode it with the VAE.\n"
            "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
            "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
        )
    explanation += f"\n⏱ **Runtime:** {total:.2f}s • **Steps:** {len(latents_list)} • Seed: {seed_val}"

    pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx) if norms else None

    # State drives later slider-driven updates in update_step().
    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
    }

    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )
| |
|
| |
|
| | |
| |
|
def plot_pca(points, idx):
    """
    Build the PCA trajectory plot, highlighting the current step.

    Args:
        points: (N, 2) numpy array of per-step PCA coordinates.
        idx: index of the step to highlight; out-of-range indices are
            simply not highlighted.

    Returns:
        plotly Figure, or None for empty input.
    """
    if points.shape[0] == 0:
        return None

    steps = list(range(points.shape[0]))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=[f"step {i}" for i in steps],
    ))
    # Overlay a red marker on the currently selected step, if valid.
    if 0 <= idx < len(steps):
        fig.add_trace(go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        ))
    fig.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return fig
| |
|
| |
|
def plot_norm(norms, idx):
    """
    Build the latent-L2-norm-vs-step plot, highlighting the current step.

    Args:
        norms: list of per-step L2 norms.
        idx: index of the step to highlight; out-of-range indices are
            simply not highlighted.

    Returns:
        plotly Figure, or None for empty input.
    """
    if not norms:
        return None
    steps = list(range(len(norms)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
    ))
    # Overlay a red marker on the currently selected step, if valid.
    if 0 <= idx < len(steps):
        fig.add_trace(go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        ))
    fig.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return fig
| |
|
| |
|
| | |
| |
|
def update_step(state, idx):
    """
    Handle a step-slider move: decode that step's latent and refresh plots.

    Args:
        state: dict produced by run_diffusion with keys
            "latents", "pca", "norms"; may be empty/None before a run.
        idx: requested step index (clamped into the valid range).

    Returns:
        3-tuple of gr.update objects for (step image, PCA plot, norm plot).
    """
    if not state or "latents" not in state:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    latents = state["latents"]
    pca_pts = state["pca"]
    norms = state["norms"]

    if not latents:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    # Clamp the slider value into the valid step range.
    idx = max(0, min(int(idx), len(latents) - 1))

    pipe = get_pipe()

    try:
        img = decode_latent(pipe, latents[idx])
    except Exception:
        # Decoding is best-effort; show an empty image rather than crash.
        img = None

    pca_fig = plot_pca(pca_pts, idx) if pca_pts is not None else None
    norm_fig = plot_norm(norms, idx) if norms else None

    return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)
| |
|
| |
|
| | |
| |
|
# ---- Gradio UI: wiring of inputs, outputs, and event handlers ----
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:

    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3,
            )
            steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")

        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")

    gr.Markdown("### 🔍 Explore the denoising process step-by-step")

    # Slider max starts at 0; run_diffusion resizes it via gr.update.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")

    # Holds the per-run latents/plots data for slider-driven updates.
    state = gr.State()

    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state],
    )

    step_slider.change(
        update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot],
    )

demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, pwa=True)