Latent Diffusion Model: Text-Conditional CelebA-HQ

A complete Latent Diffusion Model (LDM) pipeline for text-conditional face generation, trained on CelebA-HQ at 256x256 resolution. A VQ-VAE first compresses images into a discrete latent space; a text-conditioned U-Net diffusion model is then trained in that compressed space.

Model Description

This is a two-stage generative model:

  1. Stage 1 (VQ-VAE): Compresses 256x256 RGB images into a 4-channel discrete latent representation using a codebook of 8192 entries. Trained with reconstruction (MSE), perceptual (VGG-based LPIPS), and adversarial (PatchGAN discriminator) losses.
  2. Stage 2 (LDM U-Net): A conditional U-Net diffusion model that operates in the VQ-VAE latent space. Text conditioning comes from frozen CLIP text embeddings (512-dim), with classifier-free guidance.
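
The "discrete" latent in Stage 1 comes from vector quantization: each encoder output vector is replaced by its nearest codebook entry. The sketch below illustrates that lookup in plain Python on a toy 2-dimensional codebook. The function name `quantize` and the toy values are illustrative assumptions, not code from models/vqvae.py, which presumably works on batched tensors with the 8192-entry codebook.

```python
def quantize(vectors, codebook):
    """Map each vector to its nearest codebook entry (squared Euclidean distance)."""
    indices = []
    for v in vectors:
        # Distance from this vector to every codebook entry
        dists = [sum((a - b) ** 2 for a, b in zip(v, entry)) for entry in codebook]
        indices.append(dists.index(min(dists)))
    quantized = [codebook[i] for i in indices]
    return indices, quantized


# Toy codebook: 3 entries of dimension 2 (hypothetical values)
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
encoder_outputs = [[0.9, 1.2], [-0.1, 0.2], [-0.8, 0.7]]

indices, quantized = quantize(encoder_outputs, codebook)
print(indices)    # codebook index chosen for each encoder output
print(quantized)  # the codebook vectors that replace the encoder outputs
```

Only the indices need to be stored to describe an image; the decoder looks the vectors up again from the shared codebook.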

VQ-VAE Architecture

| Parameter | Value |
|---|---|
| Latent channels (z) | 4 |
| Codebook size | 8192 |
| Down channels | [64, 128, 256, 256] |
| Downsampling stages | 3 |
| Attention | None (encoder and decoder) |
| Loss components | MSE (reconstruction) + LPIPS (perceptual) + GAN (adversarial) |
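
As a rough check on these values: if each of the 3 downsampling stages halves the spatial resolution (an assumption; the stride is not stated here), a 256x256 input maps to a 32x32 latent grid with 4 channels. The helper `latent_shape` below is hypothetical and only does the arithmetic.

```python
def latent_shape(image_size, num_down_stages, z_channels, factor_per_stage=2):
    """Latent (channels, height, width), assuming each stage divides H and W by factor_per_stage."""
    side = image_size // (factor_per_stage ** num_down_stages)
    return (z_channels, side, side)


shape = latent_shape(image_size=256, num_down_stages=3, z_channels=4)

# Compare the number of values in the RGB image with the number in the latent
values_in = 3 * 256 * 256
values_out = shape[0] * shape[1] * shape[2]
ratio = values_in / values_out

print(shape)
print(ratio)
```

Under that assumption the diffusion model works on tensors with far fewer values than the pixel-space image, which is the main reason for training it in latent space.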

LDM U-Net Architecture

| Parameter | Value |
|---|---|
| Down channels | [256, 384, 512, 768] |
| Mid channels | [768, 512] |
| Attention | All down levels |
| Attention heads | 16 |
| Time embedding dim | 512 |
| Condition | CLIP text embeddings (512-dim) |
| CFG dropout probability | 0.1 |
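
A CFG dropout probability of 0.1 typically means that, during training, the text condition is swapped for an "empty" (null) condition about 10% of the time, so the same U-Net learns both conditional and unconditional predictions. The sketch below shows that idea in plain Python. The name `maybe_drop_condition` and the all-zeros null embedding are assumptions for illustration; the repository may represent the unconditional input differently (for example, with the CLIP embedding of an empty prompt).

```python
import random


def maybe_drop_condition(text_embedding, null_embedding, drop_prob=0.1, rng=random):
    """Return the null embedding with probability drop_prob, else the text embedding."""
    if rng.random() < drop_prob:
        return null_embedding
    return text_embedding


rng = random.Random(0)         # seeded for repeatability
text_embedding = [0.5] * 512   # stand-in for a 512-dim CLIP text embedding
null_embedding = [0.0] * 512   # hypothetical "no condition" embedding

num_trials = 10_000
dropped = sum(
    maybe_drop_condition(text_embedding, null_embedding, 0.1, rng) is null_embedding
    for _ in range(num_trials)
)
drop_fraction = dropped / num_trials
print(f"fraction of samples trained without text: {drop_fraction:.3f}")
```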

Diffusion Process

| Parameter | Value |
|---|---|
| Timesteps (T) | 1000 |
| Beta schedule | Linear, start=0.00085, end=0.012 |
| Classifier-free guidance | Enabled (scale set via cf_guidance_scale) |
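
These schedule and guidance settings can be written out as follows. This is a minimal plain-Python sketch, not the contents of scheduler.py. It assumes the betas are spaced linearly in beta itself (some latent diffusion implementations instead interpolate in sqrt(beta) between the same endpoints), and the names `linear_betas`, `cumulative_alphas`, `noise_coefficients`, and `guided_noise` are hypothetical.

```python
import math


def linear_betas(num_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """Evenly spaced betas from beta_start to beta_end (inclusive)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def noise_coefficients(alpha_bars, t):
    """Coefficients in x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    return math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])


def guided_noise(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: move from the unconditional toward the conditional prediction."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


betas = linear_betas()
alpha_bars = cumulative_alphas(betas)
signal_coef, noise_coef = noise_coefficients(alpha_bars, t=999)
print(f"t=999: signal {signal_coef:.4f}, noise {noise_coef:.4f}")
```

With `guidance_scale=1.0` the guided prediction equals the conditional one; larger values push samples further toward the text prompt, usually at some cost in diversity.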

Training Details

| Stage | Epochs | Learning rate | Batch size |
|---|---|---|---|
| VQ-VAE | 80 | 1e-5 | 4 |
| LDM U-Net | 100 | 5e-6 | 16 |

  • Dataset: CelebA-HQ, 256x256 RGB faces
  • The discriminator (adversarial loss) is enabled after 15,000 training steps (disc_start=15000)
  • Training is tracked with Weights & Biases
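
The disc_start setting delays the adversarial term so that the autoencoder first learns to reconstruct before the discriminator influences it. Below is a hedged sketch of how such gating might look. The function name and the weights `perceptual_weight` and `disc_weight` are hypothetical and not taken from config/celebahq.yaml, and whether step 15,000 itself counts as enabled is also an assumption.

```python
def generator_loss(recon_loss, perceptual_loss, adversarial_loss, step,
                   disc_start=15000, perceptual_weight=1.0, disc_weight=0.5):
    """Combine the Stage 1 loss terms, adding the adversarial term only from disc_start on."""
    loss = recon_loss + perceptual_weight * perceptual_loss
    if step >= disc_start:  # assumed boundary: enabled at step == disc_start
        loss += disc_weight * adversarial_loss
    return loss


# Same per-term losses, before and after the discriminator is switched on
early = generator_loss(1.0, 0.5, 2.0, step=0)
late = generator_loss(1.0, 0.5, 2.0, step=15000)
print(early, late)
```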

Repository Contents

| Path | Description |
|---|---|
| models/vqvae.py | VQ-VAE encoder/decoder with codebook |
| models/discriminator.py | PatchGAN discriminator |
| models/lpips.py | Perceptual loss (VGG-based LPIPS) |
| models/unet_cond.py | Text-conditional LDM U-Net |
| models/blocks.py | Shared building blocks |
| train_vqvae.py | Stage 1 training script |
| train_ldm.py | Stage 2 training script |
| scheduler.py | Noise scheduler |
| config/celebahq.yaml | Full training config |
| dataset/ | CelebA-HQ parquet dataloader |
| celebhq/vqvae_autoencoder_ckpt.pth | VQ-VAE checkpoint |

How to Use

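import torch
import yaml

from models.vqvae import VQVAE
from models.unet_cond import UNet  # Stage 2 model, not used in this snippet

# Load the training config; the autoencoder hyperparameters are under "autoencoder_params"
with open("config/celebahq.yaml") as f:
    config = yaml.safe_load(f)

# Rebuild the VQ-VAE and load the Stage 1 checkpoint onto the CPU
vqvae = VQVAE(**config["autoencoder_params"])
state_dict = torch.load("celebhq/vqvae_autoencoder_ckpt.pth", map_location="cpu")
vqvae.load_state_dict(state_dict)
vqvae.eval()

# Placeholder input of shape (batch, 3, 256, 256); replace it with a real image
# preprocessed the same way as during training
image_tensor = torch.randn(1, 3, 256, 256)

# Encode the image to latent space, then decode it back
with torch.no_grad():
    z, _, _ = vqvae.encode(image_tensor)
    reconstruction = vqvae.decode(z)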

License

MIT
