"""Evaluate a W4A8 SVDQuant student: held-out velocity-matching loss vs the teacher
(directly comparable to the block-surgery numbers in RESULTS.md) + teacher-vs-quant
image montages.

Usage: python3 -u scripts/13_eval_svdquant.py [QUANT_DIR]
       (or set the QUANT_DIR environment variable; default: outputs/svdquant_r32_a0.5_w4a8)
"""
import json
import os
import sys

import torch

from flux2distill.data import LatentCaptionDataset
from flux2distill.losses import velocity_match_loss, build_x_t
from flux2distill.model_utils import load_pipeline, load_transformer
from flux2distill.eval_utils import side_by_side
from flux2distill import svdquant as sq

# Quantized-run directory: first CLI argument, else $QUANT_DIR, else the default run.
QUANT_DIR = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("QUANT_DIR", "outputs/svdquant_r32_a0.5_w4a8")
# `with` closes the handle deterministically (the bare `json.load(open(...))` leaked it);
# JSON is UTF-8 by spec, so don't depend on the platform's locale encoding.
with open(f"{QUANT_DIR}/quant_config.json", encoding="utf-8") as _cfg_file:
    cfg = json.load(_cfg_file)
OUT = f"{QUANT_DIR}/eval"
os.makedirs(OUT, exist_ok=True)
print(f"=== eval {QUANT_DIR}: W{cfg['w_bits']}A{cfg['a_bits']} rank={cfg['rank']} "
      f"alpha={cfg['alpha']} ===", flush=True)

pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

# rebuild the quantized student skeleton from the saved specs, then load its trained weights
student = load_transformer(dtype="bfloat16", device="cuda").eval()
sq.apply_svdquant_empty(student, cfg["specs"], w_bits=cfg["w_bits"], a_bits=cfg["a_bits"],
                        w_group=cfg["w_group"], a_group=cfg.get("a_group", 0),
                        dtype=torch.bfloat16)
# load via CPU: a cuda-mapped load would hold a 2nd full copy on the GPU (OOM on 32 GB)
_sd = torch.load(f"{QUANT_DIR}/quant_state.pt", map_location="cpu")
missing, unexpected = student.load_state_dict(_sd, strict=False)
# Explicit raise instead of `assert`: asserts are stripped under `python -O`, and a
# checkpoint that does not match the rebuilt student must never be silently accepted.
if unexpected:
    raise RuntimeError(f"unexpected keys: {unexpected[:5]}")
# strict=False tolerates student keys absent from the checkpoint (they keep the values the
# student was built with); log how many so a truncated/mismatched file is visible.
print(f"quant_state.pt: {len(_sd)} keys loaded, {len(missing)} student keys not in checkpoint "
      f"(left as built)", flush=True)
del _sd
torch.cuda.empty_cache()
student.requires_grad_(False)

# ---- fixed held-out eval batch (SAME construction as 08_train_recover for continuity) ----
# NOTE: keep seeds and statement order unchanged. eps and sigma are drawn in sequence from one
# seeded generator, so any reordering changes the batch and breaks comparability of the loss.
ds = LatentCaptionDataset(cache_dir="data/monet_cache")
EVAL_N = 16  # held-out batch size; assumes the cache holds >= EVAL_N samples — TODO confirm
ev_x0 = ds.latents[:EVAL_N].to("cuda", torch.bfloat16)  # first EVAL_N cached latents
ev_caps = ds.captions[:EVAL_N]  # their matching captions
with torch.no_grad():
    # ev_pe is later passed as encoder_hidden_states and ev_tid as txt_ids (see velocity())
    ev_pe, ev_tid = pipe.encode_prompt(ev_caps, device="cuda")
ev_gen = torch.Generator(device="cuda").manual_seed(1234)
# noise first, then one sigma ~ U[0, 1) per sample, both from the same seeded generator
ev_eps = torch.randn(ev_x0.shape, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_sigma = torch.rand(EVAL_N, generator=ev_gen, device="cuda", dtype=torch.float32)
# model input built by the project helper build_x_t in fp32, then cast to bf16
ev_xt = build_x_t(ev_x0.float(), ev_eps, ev_sigma).to(torch.bfloat16)
# only the image position ids are kept (the returned latents are discarded); arguments are
# presumably (batch=1, channels=32, height=512, width=512, ...) — TODO confirm the signature
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))

# prompts are encoded; the 7.5 GB text encoder is idle until the montages — park it in RAM
pipe.text_encoder.to("cpu")
torch.cuda.empty_cache()


def velocity(tf, x_t, sigma, pe, tid, ids=None):
    """Run transformer ``tf`` once and return its prediction for the image tokens.

    Args:
        tf: transformer module (teacher or quantized student).
        x_t: model input tokens; dim 1 is the image-token axis.
        sigma: per-sample timestep passed as ``timestep``.
        pe: prompt embeddings, passed as ``encoder_hidden_states``.
        tid: text ids, passed as ``txt_ids``.
        ids: image ids passed as ``img_ids``. Defaults to the module-level ``img_ids``
            (the original behavior); pass explicitly to reuse this helper elsewhere.

    Returns:
        The first output of ``tf`` truncated to ``x_t.size(1)`` tokens along dim 1.
        Presumably this drops any extra (non-image) tokens the model appends — TODO confirm.
    """
    if ids is None:
        ids = img_ids
    out = tf(hidden_states=x_t, timestep=sigma, guidance=None,
             encoder_hidden_states=pe, txt_ids=tid, img_ids=ids, return_dict=False)[0]
    return out[:, : x_t.size(1)]


# Score the student against the teacher on the fixed held-out batch (no grads, bf16 autocast).
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    vt = velocity(teacher, ev_xt, ev_sigma, ev_pe, ev_tid)
    vs = velocity(student, ev_xt, ev_sigma, ev_pe, ev_tid)
    loss = float(velocity_match_loss(vs, vt))
    # relative L2 error of the whole velocity field, measured in fp32;
    # the 1e-8 guards against a zero teacher norm
    vt_fp32 = vt.float()
    err_norm = (vs.float() - vt_fp32).norm()
    rel = float(err_norm / (vt_fp32.norm() + 1e-8))
print(f"eval_vel_loss={loss:.4f}   vel_rel_err={rel:.4f}", flush=True)

# Size / reconstruction stats recorded by the quantization run itself.
summ = cfg["summary"]
print(
    f"quant: {summ['n_quant_layers']} layers, effective {summ['quant_MB']:.0f}MB "
    f"({summ['ratio']:.2f}x), weight-recon mean rel-err {cfg['diag']['mean_rel_err']:.4f}",
    flush=True,
)

# ---- image montages: teacher vs quant on probe prompts ----
# The original four probes, kept so new montages stay comparable with earlier runs.
_ORIGINAL_PROBES = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
    "a photorealistic portrait of an elderly fisherman, weathered face, sharp detail",
    "a bustling tokyo street at night, neon signs, rain-slicked pavement, reflections",
]
# Four probes added 2026-06-01 to target the capabilities quantization is most likely to bend:
# a second text-rendering case, multi-object composition/counting, a hands/face case, and a
# fine-texture macro shot.
_ADDED_PROBES = [
    'a hand-lettered chalkboard cafe sign that reads "FRESH COFFEE" with small daily specials below',
    "a flat-lay breakfast table from above: three fried eggs, two strips of bacon, a glass of orange juice, and a small vase with one sunflower",
    "a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
    "an extreme macro of a dewy spider web at dawn, water droplets catching golden light, crisp detail",
]
PROMPTS = _ORIGINAL_PROBES + _ADDED_PROBES


# The text encoder was parked in host RAM after encoding the eval captions. The pipeline calls
# below take raw prompts (and so presumably run the text encoder), so move it back to the GPU.
pipe.text_encoder.to("cuda")
torch.cuda.empty_cache()


@torch.no_grad()
def gen(tf):
    """Generate one 512x512 image per entry of PROMPTS with transformer ``tf`` in ``pipe``.

    Every call uses the same settings (seed 0, 4 inference steps, guidance_scale=1.0) so that
    teacher and student outputs are directly comparable. ``tf`` is swapped into
    ``pipe.transformer`` only for the duration of the call. The previous transformer is
    restored afterwards, even if generation raises, so ``pipe`` is not left pointing at ``tf``.

    Returns:
        The ``.images`` of the pipeline output.
    """
    prev_tf = pipe.transformer
    pipe.transformer = tf
    try:
        # a fresh, identically seeded generator per call, so each model starts from the same draws
        g = torch.Generator(device="cuda").manual_seed(0)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return pipe(prompt=PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                        height=512, width=512, generator=g).images
    finally:
        pipe.transformer = prev_tf


# Render every probe with both models, then save one side-by-side comparison per prompt.
t_imgs = gen(teacher)
torch.cuda.empty_cache()
q_imgs = gen(student)
quant_label = f"W{cfg['w_bits']}A{cfg['a_bits']} r{cfg['rank']}"
for idx, pair in enumerate(zip(t_imgs, q_imgs)):
    teacher_img, quant_img = pair
    montage = side_by_side(teacher_img, quant_img, "teacher", quant_label, PROMPTS[idx])
    montage.save(f"{OUT}/cmp_{idx}.png")
print(f"saved {len(PROMPTS)} montages -> {OUT}/", flush=True)
