# std3.5mt / app.py
# programmersd's picture
# Update app.py
# 667e250 verified
import gc
import hashlib
import json
import os
import random
import shutil
import time

import gradio as gr
import psutil
import torch
from diffusers import DiffusionPipeline
from PIL import Image, PngImagePlugin
# Module-level setup: constants, working directories, and the pipeline load.
# Everything here runs at import time, in this order.

# Model identifier passed to DiffusionPipeline.from_pretrained below.
MODEL_ID = "tensorart/stable-diffusion-3.5-medium-turbo"
# Directory handed to from_pretrained as cache_dir.
CACHE_DIR = "./hf_cache"
# Directory where generate() saves result PNGs and looks for cached results.
OUTPUT_DIR = "./outputs"
# Size budget, in GiB, that prune_cache() enforces on OUTPUT_DIR
# (despite the name it is not applied to CACHE_DIR).
MAX_CACHE_SIZE_GB = 2
# Both directories must exist before the model is loaded and before any
# image is saved.
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# The app is pinned to CPU with float32 weights; generate() builds its
# torch.Generator on this same device.
device = "cpu"
dtype = torch.float32
# NOTE(review): safety_checker=None is forwarded as-is; confirm the pipeline
# class resolved for MODEL_ID actually accepts this argument.
pipe = DiffusionPipeline.from_pretrained(
MODEL_ID,
torch_dtype=dtype,
safety_checker=None,
cache_dir=CACHE_DIR,
low_cpu_mem_usage=True
)
pipe.to(device)
# Presumably memory-saving options for CPU inference -- TODO confirm both are
# supported by the loaded pipeline class.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
# Turn off the pipeline's own progress bar output.
pipe.set_progress_bar_config(disable=True)
def warmup():
    """Run one minimal generation through the global pipeline.

    Executes a single-step, 256x256 call with the prompt "warmup" and
    guidance disabled, discards the result, then forces a garbage
    collection. Presumably this exists so that one-time initialization
    cost is paid at startup rather than on the first user request.
    """
    warmup_args = {
        "prompt": "warmup",
        "num_inference_steps": 1,
        "guidance_scale": 0.0,
        "width": 256,
        "height": 256,
    }
    with torch.inference_mode():
        pipe(**warmup_args)
    # NOTE(review): the original indentation was lost; this collection is
    # assumed to be the last statement of the function.
    gc.collect()
# Run the warm-up pass once at import time, before the UI below is built.
warmup()
def get_ram_usage():
    """Return the system's used memory in GiB, rounded to two decimals.

    The figure comes from ``psutil.virtual_memory().used``, so it is the
    machine-wide value rather than this process's own footprint.
    """
    bytes_per_gib = 1024 ** 3
    used_bytes = psutil.virtual_memory().used
    return round(used_bytes / bytes_per_gib, 2)
def prune_cache(directory=None, max_size_gb=None):
    """Delete the oldest files in a directory until it fits a size budget.

    Only regular files directly inside the directory are considered;
    subdirectories are neither counted nor removed. Files are deleted in
    order of modification time, oldest first, until the combined size is
    at or below the budget.

    Args:
        directory: Directory to prune. Defaults to ``OUTPUT_DIR``.
        max_size_gb: Size budget in GiB. Defaults to ``MAX_CACHE_SIZE_GB``.

    Raises:
        OSError: If the directory itself cannot be listed. Files that
            disappear while pruning is in progress are skipped silently.
    """
    if directory is None:
        directory = OUTPUT_DIR
    if max_size_gb is None:
        max_size_gb = MAX_CACHE_SIZE_GB
    max_bytes = max_size_gb * 1024 ** 3

    candidates = []
    total_size = 0
    # scandir yields type and stat information in a single pass instead of
    # separate isfile/getsize/getmtime calls per file.
    with os.scandir(directory) as entries:
        for entry in entries:
            try:
                if not entry.is_file():
                    continue
                info = entry.stat()
            except FileNotFoundError:
                # Removed by someone else between listing and stat.
                continue
            candidates.append((info.st_mtime, entry.path, info.st_size))
            total_size += info.st_size

    candidates.sort(key=lambda item: item[0])
    for _, path, size in candidates:
        if total_size <= max_bytes:
            break
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone; its bytes no longer count against the budget.
            pass
        total_size -= size
def build_cache_key(prompt, negative_prompt, steps, guidance, width, height, seed):
    """Return a SHA-256 hex digest identifying one set of generation inputs.

    The fields are serialized as a JSON list before hashing so that field
    boundaries are unambiguous. Joining them with a plain separator would
    let, for example, prompt ``"a|b"`` with negative prompt ``"c"`` collide
    with prompt ``"a"`` and negative prompt ``"b|c"``.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        steps: Number of inference steps.
        guidance: Guidance scale.
        width: Image width in pixels.
        height: Image height in pixels.
        seed: Random seed.

    Returns:
        A 64-character lowercase hexadecimal string.
    """
    fields = [prompt, negative_prompt, steps, guidance, width, height, seed]
    # default=str is deliberate: any value that could be formatted into the
    # key before must still be accepted here.
    payload = json.dumps(fields, separators=(",", ":"), default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _load_cached_image(path):
    """Return a fully decoded copy of the image at *path*, or None.

    None is returned when the file does not exist or cannot be decoded
    (for example a truncated PNG left behind by an interrupted save).
    """
    try:
        # Image.open is lazy and keeps the file open; copy() forces the
        # decode so the handle can be closed before returning.
        with Image.open(path) as cached:
            return cached.copy()
    except OSError:
        return None


def generate(prompt, negative_prompt, steps, guidance, width, height, seed):
    """Generate an image for the given settings, reusing a cached PNG if any.

    Args:
        prompt: Positive prompt; must contain non-whitespace text.
        negative_prompt: Negative prompt; ``None`` is treated as empty.
        steps: Inference steps, clamped to 1..8.
        guidance: Guidance scale, clamped to 0.0..7.5.
        width: Image width in pixels, clamped to 256..768.
        height: Image height in pixels, clamped to 256..768.
        seed: Integer seed; ``-1`` or ``None`` picks a random seed.

    Returns:
        A ``(image, status)`` tuple. ``image`` is a PIL image, or ``None``
        when validation or generation failed; ``status`` is a
        human-readable message for the UI.
    """
    start_time = time.time()

    # The UI sends strings, but other callers may pass None.
    prompt = prompt or ""
    negative_prompt = negative_prompt or ""
    if not prompt.strip():
        return None, "Prompt cannot be empty."

    try:
        width = max(256, min(int(width), 768))
        height = max(256, min(int(height), 768))
        steps = max(1, min(int(steps), 8))
        guidance = max(0.0, min(float(guidance), 7.5))
        # A number field can deliver a float, or None when cleared;
        # manual_seed needs an int.
        seed = -1 if seed is None else int(seed)
    except (TypeError, ValueError) as e:
        return None, f"Error: {str(e)}"

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    cache_key = build_cache_key(prompt, negative_prompt, steps, guidance, width, height, seed)
    cache_path = os.path.join(OUTPUT_DIR, f"{cache_key}.png")

    # An unreadable cache file is treated as a miss and overwritten below.
    image = _load_cached_image(cache_path)
    if image is not None:
        duration = round(time.time() - start_time, 2)
        ram = get_ram_usage()
        return image, f"Loaded from cache | Seed: {seed} | Time: {duration}s | RAM: {ram}GB"

    try:
        generator = torch.Generator(device=device).manual_seed(seed)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                width=width,
                height=height,
                generator=generator,
            )
        image = result.images[0]
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {str(e)}"
    finally:
        gc.collect()

    # Persisting is best-effort: a failed write or prune must not throw away
    # an image that was generated successfully.
    cache_note = ""
    try:
        metadata = PngImagePlugin.PngInfo()
        metadata.add_text("prompt", prompt)
        metadata.add_text("negative_prompt", negative_prompt)
        metadata.add_text("steps", str(steps))
        metadata.add_text("guidance", str(guidance))
        metadata.add_text("seed", str(seed))
        image.save(cache_path, pnginfo=metadata)
        prune_cache()
    except OSError as e:
        cache_note = f" | Cache write failed: {e}"

    duration = round(time.time() - start_time, 2)
    ram = get_ram_usage()
    return image, f"Generated | Seed: {seed} | Time: {duration}s | RAM: {ram}GB{cache_note}"
# Gradio UI: prompt fields, sampling controls, and a button wired to generate().
with gr.Blocks(title="SD 3.5 Turbo - Ultimate CPU Mode") as demo:
    gr.Markdown("## Stable Diffusion 3.5 Medium Turbo - Ultimate CPU Edition")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")

    # Slider bounds mirror the clamping ranges applied inside generate().
    with gr.Row():
        steps = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Steps")
        guidance = gr.Slider(minimum=0.0, maximum=7.5, value=0.0, step=0.5, label="Guidance")

    # NOTE(review): the original indentation was lost; the seed field is
    # assumed to sit in the same row as width and height.
    with gr.Row():
        width = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Width")
        height = gr.Slider(minimum=256, maximum=768, value=512, step=64, label="Height")
        seed = gr.Number(value=-1, label="Seed (-1 random)")

    generate_btn = gr.Button("Generate")
    output_image = gr.Image(type="pil")
    status = gr.Textbox(label="Status")

    # Input order must match generate()'s positional parameters.
    generate_btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, steps, guidance, width, height, seed],
        outputs=[output_image, status],
    )
# Queue requests and run one generation at a time. The keyword that sets the
# concurrency limit differs between Gradio major versions, so try the newer
# name first and fall back to the 3.x name when it is not recognized.
try:
    demo.queue(max_size=10, default_concurrency_limit=1, status_update_rate=1)
except TypeError:
    demo.queue(max_size=10, concurrency_count=1, status_update_rate=1)
demo.launch()