Latent Diffusion Model: Text-Conditional CelebA-HQ
A full Latent Diffusion Model (LDM) pipeline for text-conditional face image generation, trained on CelebA-HQ at 256x256 resolution. The model compresses images into a discrete latent space with VQ-VAE, then trains a text-conditioned U-Net diffusion model in that compressed space.
Model Description
This is a two-stage generative model:
- Stage 1 (VQ-VAE): Compresses 256x256 RGB images into a 4-channel latent representation that is vector-quantized against a codebook of 8192 entries. Trained with reconstruction, perceptual (VGG-based LPIPS), and adversarial (discriminator) losses.
- Stage 2 (LDM U-Net): A conditional U-Net diffusion model that operates in the VQ-VAE latent space. Text conditioning comes from frozen CLIP text embeddings (512-dim), and sampling supports classifier-free guidance.
VQ-VAE Architecture
| Parameter | Value |
|---|---|
| Latent channels (z) | 4 |
| Codebook size | 8192 |
| Down channels | [64, 128, 256, 256] |
| Downsampling stages | 3 |
| Attention | None in encoder/decoder |
| Loss components | MSE (reconstruction) + LPIPS (perceptual) + GAN (adversarial) |
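To make the codebook row concrete, here is a minimal NumPy sketch of the nearest-neighbour lookup a VQ-VAE performs on the encoder output. It assumes (not stated in this card) that each of the 3 downsampling stages halves the spatial size, so a 256x256 image maps to a 4x32x32 latent. The random codebook stands in for the learned 8192x4 embedding table, and `quantize` is a hypothetical helper, not the API of `models/vqvae.py`.

```python
import numpy as np

Z_CHANNELS = 4         # latent channels (z)
CODEBOOK_SIZE = 8192   # number of codebook entries
IMG_SIZE = 256
NUM_DOWN = 3           # assumption: each downsampling stage halves H and W

LATENT_SIZE = IMG_SIZE // 2 ** NUM_DOWN  # 256 // 8 = 32


def quantize(z_e: np.ndarray, codebook: np.ndarray):
    """Replace each latent vector with its nearest codebook entry (L2 distance).

    z_e:      (C, H, W) continuous encoder output
    codebook: (K, C) embedding table
    Returns the quantized latent (C, H, W) and the (H, W) code indices.
    """
    c, h, w = z_e.shape
    flat = z_e.reshape(c, -1).T                      # (H*W, C)
    # Squared distances ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2 for every pair
    d = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )                                                # (H*W, K)
    idx = d.argmin(axis=1)                           # nearest code per position
    z_q = codebook[idx].T.reshape(c, h, w)           # back to (C, H, W)
    return z_q, idx.reshape(h, w)


rng = np.random.default_rng(0)
codebook = rng.standard_normal((CODEBOOK_SIZE, Z_CHANNELS)).astype(np.float32)
z_e = rng.standard_normal((Z_CHANNELS, LATENT_SIZE, LATENT_SIZE)).astype(np.float32)
z_q, codes = quantize(z_e, codebook)
print(z_q.shape, codes.shape)  # (4, 32, 32) (32, 32)
```

Under that downsampling assumption, each image is stored as 32 x 32 = 1,024 codebook indices (4 x 32 x 32 = 4,096 latent values) instead of 3 x 256 x 256 = 196,608 pixel values, a 48x reduction in the number of values the diffusion model has to handle.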
LDM U-Net Architecture
| Parameter | Value |
|---|---|
| Down channels | [256, 384, 512, 768] |
| Mid channels | [768, 512] |
| Attention | All down levels |
| Attention heads | 16 |
| Time embedding dim | 512 |
| Conditioning | CLIP text embeddings (512-dim) |
| CFG dropout probability | 0.1 |
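The CFG dropout probability means that during training the text condition is swapped for an "unconditional" embedding for roughly 10% of samples, so one network learns both the conditional and the unconditional noise prediction. Below is a minimal NumPy sketch. Using an all-zeros vector as the unconditional embedding is an assumption (some implementations use the CLIP embedding of an empty prompt instead), and `drop_condition` is a hypothetical helper, not a function from `train_ldm.py`.

```python
import numpy as np

CFG_DROPOUT = 0.1  # CFG dropout probability from the table above
EMB_DIM = 512      # CLIP text embedding size


def drop_condition(text_emb: np.ndarray, p: float, rng: np.random.Generator):
    """Replace each sample's whole text embedding with a null embedding with probability p.

    text_emb: (B, ...) batch of embeddings, e.g. (B, 512) pooled or (B, L, 512) per token
    Returns the masked batch and the boolean drop mask of shape (B,).
    """
    drop = rng.random(text_emb.shape[0]) < p                    # one coin flip per sample
    mask = drop.reshape((-1,) + (1,) * (text_emb.ndim - 1))     # broadcast over other dims
    uncond = np.zeros_like(text_emb)                            # assumed "null" embedding
    return np.where(mask, uncond, text_emb), drop


rng = np.random.default_rng(0)
batch = rng.standard_normal((16, EMB_DIM))
cond, dropped = drop_condition(batch, CFG_DROPOUT, rng)
print(int(dropped.sum()), "of 16 samples trained unconditionally")
```

Dropping per sample (rather than per batch) keeps every batch mostly conditional while still giving the model a steady stream of unconditional examples.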
Diffusion Process
| Parameter | Value |
|---|---|
| Timesteps (T) | 1000 |
| Beta schedule | Linear, start = 0.00085, end = 0.012 |
| Classifier-free guidance | Enabled (scale set by `cf_guidance_scale`) |
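The schedule in the table can be written down directly. Below is a minimal NumPy sketch of a linear beta schedule with T = 1000, start 0.00085 and end 0.012, plus the closed-form forward noising q(z_t | z_0) = N(sqrt(alpha_bar_t) z_0, (1 - alpha_bar_t) I) that the U-Net learns to invert. This card does not say whether `scheduler.py` interpolates beta itself (as here) or its square root (the "scaled linear" variant used by Stable Diffusion), so check the script before relying on exact values. `add_noise` is a hypothetical name.

```python
import numpy as np

T = 1000
BETA_START, BETA_END = 0.00085, 0.012

# Linear schedule: beta_t evenly spaced between start and end (0-indexed t)
betas = np.linspace(BETA_START, BETA_END, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)  # alpha_bar_t = prod_{s<=t} (1 - beta_s)


def add_noise(z0: np.ndarray, t: int, noise: np.ndarray) -> np.ndarray:
    """Sample z_t ~ q(z_t | z_0) in closed form for a 0-indexed timestep t."""
    return np.sqrt(alpha_bars[t]) * z0 + np.sqrt(1.0 - alpha_bars[t]) * noise


rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 32, 32))   # a latent of the assumed 4x32x32 shape
eps = rng.standard_normal(z0.shape)
z_t = add_noise(z0, 500, eps)
print(f"alpha_bar at t=0: {alpha_bars[0]:.5f}, at t=999: {alpha_bars[-1]:.5f}")
```

During training, a random t is drawn per sample, `eps` is the regression target, and the U-Net sees only `z_t`, `t`, and the text condition.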
Training Details
| Stage | Epochs | Learning rate | Batch size |
|---|---|---|---|
| VQ-VAE | 80 | 1e-5 | 4 |
| LDM U-Net | 100 | 5e-6 | 16 |
- Dataset: CelebA-HQ, 256x256 RGB faces
- Discriminator enabled after 15,000 steps (disc_start=15000)
- Training tracked with Weights & Biases
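The delayed discriminator start can be expressed as a step-dependent switch on the adversarial term of the autoencoder loss. The sketch below is illustrative only: `DISC_START` comes from this card, but the unit perceptual weight and the `disc_weight` of 0.5 are assumptions, `vqvae_generator_loss` is a hypothetical helper, and the actual weighting lives in `train_vqvae.py`.

```python
DISC_START = 15_000  # discriminator enabled after this many steps (from the card)


def vqvae_generator_loss(step: int, recon: float, perceptual: float, adversarial: float,
                         perceptual_weight: float = 1.0, disc_weight: float = 0.5) -> float:
    """Combine the autoencoder's loss terms for one optimisation step.

    The weights are illustrative assumptions; only DISC_START comes from the model card.
    """
    loss = recon + perceptual_weight * perceptual
    if step > DISC_START:
        # The adversarial term only contributes once the discriminator is switched on
        loss += disc_weight * adversarial
    return loss


print(vqvae_generator_loss(10_000, 0.2, 0.3, 1.0))  # before disc_start: 0.5
print(vqvae_generator_loss(20_000, 0.2, 0.3, 1.0))  # after disc_start: 1.0
```

Holding the adversarial term back lets the autoencoder first learn reasonable reconstructions from the MSE and LPIPS terms, so the discriminator does not dominate the early, blurry phase of training.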
Repository Contents
| Path | Description |
|---|---|
| `models/vqvae.py` | VQ-VAE encoder/decoder with codebook |
| `models/discriminator.py` | PatchGAN discriminator |
| `models/lpips.py` | Perceptual loss (VGG-based LPIPS) |
| `models/unet_cond.py` | Text-conditional LDM U-Net |
| `models/blocks.py` | Shared building blocks |
| `train_vqvae.py` | Stage 1 training script |
| `train_ldm.py` | Stage 2 training script |
| `scheduler.py` | Noise scheduler |
| `config/celebahq.yaml` | Full training config |
| `dataset/` | CelebA-HQ parquet dataloader |
| `celebhq/vqvae_autoencoder_ckpt.pth` | VQ-VAE checkpoint |
How to Use
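import yaml
import torch
from models.vqvae import VQVAE
from models.unet_cond import UNet  # stage-2 model (not used in this snippet)

with open("config/celebahq.yaml") as f:
    config = yaml.safe_load(f)  # full training config

vqvae = VQVAE(**config["autoencoder_params"])  # build the autoencoder from its config section
vqvae.load_state_dict(torch.load("celebhq/vqvae_autoencoder_ckpt.pth", map_location="cpu"))
vqvae.eval()  # inference mode

# image_tensor must be defined beforehand: a (B, 3, 256, 256) batch preprocessed as in training
with torch.no_grad():
    z, _, _ = vqvae.encode(image_tensor)  # first output is the latent passed to decode
    reconstruction = vqvae.decode(z)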
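The snippet above exercises only the autoencoder. Text-to-image generation additionally runs the reverse diffusion loop in latent space and then decodes the final latent with `vqvae.decode`. The NumPy sketch below shows DDPM ancestral sampling with classifier-free guidance. `eps_model` is a hypothetical stand-in for the conditional U-Net (the real call signature in `models/unet_cond.py` is not documented here), 7.5 is an arbitrary example value for `cf_guidance_scale`, and the variance choice sigma_t^2 = beta_t is one common DDPM option, not necessarily what the repository's scheduler uses.

```python
import numpy as np

T = 1000
betas = np.linspace(0.00085, 0.012, T)   # linear schedule from the diffusion table
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def cfg_eps(eps_model, z_t, t, cond, uncond, scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    e_c = eps_model(z_t, t, cond)
    e_u = eps_model(z_t, t, uncond)
    return e_u + scale * (e_c - e_u)


def sample(eps_model, shape, cond, uncond, scale=7.5, seed=0):
    """DDPM ancestral sampling in latent space (sigma_t^2 = beta_t variant)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape)
    for t in reversed(range(T)):
        eps = cfg_eps(eps_model, z, t, cond, uncond, scale)
        # Posterior mean of z_{t-1} given the guided noise prediction
        mean = (z - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            z = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)
        else:
            z = mean  # no noise is added at the final step
    return z  # a real pipeline would batch this as a tensor and call vqvae.decode on it


def dummy_eps_model(z_t, t, cond):
    """Hypothetical stand-in for the U-Net: always predicts zero noise."""
    return np.zeros_like(z_t)


latent = sample(dummy_eps_model, (4, 32, 32), cond=np.ones(512), uncond=np.zeros(512))
print(latent.shape)  # (4, 32, 32)
```

Each step costs two U-Net evaluations (conditional and unconditional); a scale of 1 recovers the purely conditional prediction, and larger scales trade diversity for closer adherence to the prompt.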
License
MIT