# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "transformers>=5.11,<6",
#   "torch>=2.8,<2.10",  # a100-large job image driver is CUDA 12.9; torch 2.10+ ships cu13 wheels
#   "torchvision",  # Gemma4Processor imports it even for text-only use
#   "accelerate",
#   "huggingface-hub",
#   "pillow",
#   "pyzipper",
#   "requests",
# ]
# ///
"""Generation runs for the DiffusionGemma vs Gemma-4 post-OCR correction benchmark.

GENERATION ONLY — all metrics are computed offline by metrics.py from the
JSONL this produces, so metric changes never require re-running GPU jobs.

Designed to run on HF Jobs:
    hf jobs uv run --flavor a100-large --timeout 45m \
      -e HF_XET_HIGH_PERFORMANCE=1 -s HF_TOKEN \
      benchmark.py -- --mode smoke
"""

import argparse
import difflib
import gc
import json
import random
import re
import time
from pathlib import Path

import requests

# Hub ids of the models under test. run_model() loads DG_MODEL for the
# "diffusiongemma*" keys and G4_MODEL for "gemma4"; sample_passages() also uses
# the G4_MODEL tokenizer to measure passage length.
DG_MODEL = "google/diffusiongemma-26B-A4B-it"
G4_MODEL = "google/gemma-4-E4B-it"
# Parameter-matched AR baseline (26B MoE / 4B active, same as DiffusionGemma) —
# the model DeepMind benched DiffusionGemma against (per João Gante).
G4_MOE_MODEL = "google/gemma-4-26B-A4B-it"

# BLN600 (CC-BY-NC-4.0): resolved via the figshare API for DOI 10.15131/shef.data.25439023.
# NC license — the text is downloaded at run time and must never be committed
# to a public repo; raw outputs go to a PRIVATE dataset repo only.
BLN600_API = "https://api.figshare.com/v2/articles/25439023"
# Password handed to the AES zip reader when extracting the BLN600 archive.
BLN600_ZIP_PASSWORD = b"BLN600"

# ICDAR2019 post-OCR (CC-BY-4.0) — fallback eval set + source of the Space's
# committable examples.
ICDAR_URL = "https://zenodo.org/records/3515403/files/ICDAR2019-POCR-ground-truth.zip"

MAX_PASSAGE_TOKENS = 220  # margin under DiffusionGemma's fixed 256-token canvas
N_EXAMPLES = 6  # Space dropdown examples (ICDAR, CC-BY)

# User-turn prompt; `{ocr}` is filled with the (possibly trimmed) OCR passage.
# The exact text is recorded in the output metadata, so edits here change the
# benchmark condition.
PROMPT_TEMPLATE = """\
Correct the OCR errors in the following text from a 19th-century English newspaper.
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.

OCR text:
{ocr}"""

# Special-token strings that end the useful part of a raw (unskipped) decode.
STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")


def extract_answer(raw: str) -> tuple[str, str]:
    """Separate a raw decode into ``(answer, thought)``.

    The text is first truncated at the earliest occurrence of any entry in
    STOP_MARKERS. DiffusionGemma wraps its output as
    `<|channel>thought\\n<channel|>ANSWER<turn|><eos>...` (the thought part is
    empty when thinking is off), so if a `<channel|>` marker remains, the
    answer is whatever follows the LAST one and the thought is whatever
    follows `<|channel>thought` before it. Gemma-4 emits plain text, which is
    returned as the answer with an empty thought. Both parts are stripped.
    """
    # Earliest stop marker wins; with none present the whole string is kept.
    cut = len(raw)
    for marker in STOP_MARKERS:
        pos = raw.find(marker)
        if pos != -1 and pos < cut:
            cut = pos
    body = raw[:cut]

    head, separator, answer = body.rpartition("<channel|>")
    if not separator:
        # No channel framing at all: plain-text output.
        return body.strip(), ""

    found = re.search(r"<\|channel>thought(.*)$", head, flags=re.DOTALL)
    thought = found.group(1).strip() if found else ""
    return answer.strip(), thought


# ---------------------------------------------------------------- data


def _download(url: str, dest: Path, **kwargs) -> Path:
    """Stream `url` to `dest` unless `dest` already exists; return `dest`.

    The body is written to a sibling ``<name>.part`` file and only renamed to
    `dest` once the transfer finished, so an interrupted or failed download
    can never be mistaken for a complete cached file by the `dest.exists()`
    shortcut on the next run.

    Extra keyword arguments are forwarded to `requests.get`.
    Raises `requests.HTTPError` on a non-2xx response.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        return dest
    print(f"downloading {url} -> {dest}")
    partial = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=600, **kwargs) as r:
            r.raise_for_status()
            with partial.open("wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Same directory, so the rename is atomic: dest is either absent or complete.
        partial.replace(dest)
    except BaseException:
        # Also covers Ctrl-C: never leave a truncated temp file behind.
        partial.unlink(missing_ok=True)
        raise
    return dest


def download_bln600(workdir: Path) -> list[dict]:
    """Download + parse BLN600 into [{id, ocr_input, gold}].

    Handles both a folder layout (OCR Text/*.txt paired with Ground Truth/*.txt
    by stem, tried first) and a CSV layout ('OCR Text'/'Ground Truth' columns).

    The archive is located through the figshare API, cached under `workdir`,
    and extracted to `workdir / "bln600"`.

    Raises:
        requests.HTTPError: the figshare API call or the download failed.
        RuntimeError: the article lists no zip, or neither layout was found.
    """
    import csv
    import shutil
    from itertools import islice

    import pyzipper

    resp = requests.get(BLN600_API, timeout=60)
    # Surface an API failure as an HTTP error instead of a KeyError on the JSON below.
    resp.raise_for_status()
    files = resp.json().get("files") or []
    zips = [f for f in files if f["name"].lower().endswith(".zip")]
    if not zips:
        raise RuntimeError(f"no zip in figshare article files: {[f['name'] for f in files]}")
    zip_path = _download(zips[0]["download_url"], workdir / zips[0]["name"])

    extract_dir = workdir / "bln600"
    if not extract_dir.exists():
        # Extract into a staging directory and rename at the end, so an
        # interrupted extraction is redone rather than parsed as a complete set.
        staging = extract_dir.with_name(extract_dir.name + ".partial")
        shutil.rmtree(staging, ignore_errors=True)
        with pyzipper.AESZipFile(zip_path) as zf:
            zf.setpassword(BLN600_ZIP_PASSWORD)
            zf.extractall(staging)
        staging.rename(extract_dir)

    # folder layout: pair OCR Text/ and Ground Truth/ files by stem.
    # Only the path INSIDE the archive is matched — matching the full path
    # would misclassify every file when workdir itself contains "ocr"/"gold".
    ocr_files: dict[str, Path] = {}
    gold_files: dict[str, Path] = {}
    for p in sorted(extract_dir.rglob("*.txt")):
        folder = p.parent.relative_to(extract_dir).as_posix().lower()
        if "ocr" in folder:
            ocr_files[p.stem] = p
        if "ground" in folder or "gold" in folder:
            gold_files[p.stem] = p
    common = sorted(set(ocr_files) & set(gold_files))
    if common:
        print(f"BLN600 folder layout: {len(common)} aligned pairs")
        return [
            {
                "id": f"bln600/{stem}",
                # explicit encoding: the default is locale-dependent
                "ocr_input": ocr_files[stem].read_text(encoding="utf-8", errors="replace"),
                "gold": gold_files[stem].read_text(encoding="utf-8", errors="replace"),
            }
            for stem in common
        ]

    # CSV layout fallback (sorted so the chosen file is deterministic)
    for csv_path in sorted(extract_dir.rglob("*.csv")):
        # utf-8-sig: a BOM would otherwise be glued onto the first header name
        with csv_path.open(newline="", encoding="utf-8-sig", errors="replace") as f:
            rows = list(csv.DictReader(f))
        if rows and "OCR Text" in rows[0] and "Ground Truth" in rows[0]:
            print(f"BLN600 CSV layout: {len(rows)} rows from {csv_path.name}")
            return [
                {"id": f"bln600/{i}", "ocr_input": r["OCR Text"], "gold": r["Ground Truth"]}
                for i, r in enumerate(rows)
            ]
    # Show a bounded sample of the archive to make the failure diagnosable.
    listing = [str(p.relative_to(extract_dir)) for p in islice(extract_dir.rglob("*"), 40)]
    raise RuntimeError(f"could not parse BLN600; archive contents: {listing}")


def download_icdar_english(workdir: Path) -> list[dict]:
    """ICDAR2019 post-OCR English subset, as [{id, ocr_input, gold}].

    Format: per-passage .txt files with [OCR_toInput]/[OCR_aligned]/[ GS_aligned]
    lines; '@' are alignment pads. `ocr_input` is the [OCR_toInput] line and
    `gold` is the GS_aligned line with the pads removed. Files are selected when
    any component of their path inside the archive starts with "EN"; files
    missing either line are skipped.

    The zip is cached under `workdir` and extracted to `workdir / "icdar2019"`.
    """
    import shutil
    import zipfile

    zip_path = _download(ICDAR_URL, workdir / "icdar2019.zip")
    extract_dir = workdir / "icdar2019"
    if not extract_dir.exists():
        # Extract into a staging directory and rename at the end, so an
        # interrupted extraction is redone rather than parsed as a complete set.
        staging = extract_dir.with_name(extract_dir.name + ".partial")
        shutil.rmtree(staging, ignore_errors=True)
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(staging)
        staging.rename(extract_dir)

    passages = []
    for p in sorted(extract_dir.rglob("*.txt")):
        # POSIX form keeps both the "/EN" filter and the ids independent of
        # the platform's path separator.
        rel = p.relative_to(extract_dir).as_posix()
        if not re.search(r"(^|/)EN", rel):
            continue
        ocr = gold = None
        # utf-8-sig: explicit (the default is locale-dependent) and BOM-tolerant,
        # so a BOM cannot hide a marker on the first line.
        for line in p.read_text(encoding="utf-8-sig", errors="replace").splitlines():
            if line.startswith("[OCR_toInput]"):
                ocr = line.removeprefix("[OCR_toInput]").strip()
            elif line.startswith(("[ GS_aligned]", "[GS_aligned]")):
                gold = line.split("]", 1)[1].replace("@", "").strip()
        if ocr and gold:
            passages.append({"id": f"icdar2019/{rel}", "ocr_input": ocr, "gold": gold})
    print(f"ICDAR2019 English: {len(passages)} passages")
    return passages


def trim_pair(ocr: str, gold: str, n_tokens, max_tokens: int) -> tuple[str, str] | None:
    """Trim an aligned (ocr, gold) pair so both sides fit in max_tokens.

    Cuts at a whitespace position inside a character-aligned "equal" region so
    the pair stays parallel after trimming (independent token-count truncation
    would misalign the endings and corrupt tail CER). Returns None if no valid
    cut point exists.
    """
    if n_tokens(ocr) <= max_tokens and n_tokens(gold) <= max_tokens:
        return ocr, gold

    sm = difflib.SequenceMatcher(None, ocr, gold, autojunk=False)
    # candidate (i_cut, j_cut) pairs, ascending: whitespace inside equal blocks
    candidates = [
        (i1 + m.start(), j1 + m.start())
        for op, i1, i2, j1, _j2 in sm.get_opcodes()
        if op == "equal"
        for m in re.finditer(r"\s", ocr[i1:i2])
    ]
    if not candidates:
        return None

    def fits(idx: int) -> bool:
        i_cut, j_cut = candidates[idx]
        return n_tokens(ocr[:i_cut]) <= max_tokens and n_tokens(gold[:j_cut]) <= max_tokens

    if not fits(0):
        return None
    # token counts grow with cut position -> binary search the largest fit
    lo, hi = 0, len(candidates) - 1
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid - 1
    i_cut, j_cut = candidates[lo]
    return ocr[:i_cut].rstrip(), gold[:j_cut].rstrip()


def sample_passages(passages: list[dict], n: int, seed: int) -> list[dict]:
    """Pick up to `n` passages in a seed-determined order.

    Each candidate is passed through trim_pair with a MAX_PASSAGE_TOKENS budget
    (token counts come from the G4_MODEL tokenizer). Candidates that cannot be
    trimmed, or whose resulting gold text is shorter than 200 characters, are
    skipped and the next one in the seeded order is tried. Returns fresh dicts
    with `id`, `ocr_input` and `gold`.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(G4_MODEL)

    def n_tokens(text: str) -> int:
        encoded = tokenizer(text)
        return len(encoded["input_ids"])

    # Sampling the full population is a seeded shuffle that leaves `passages` intact.
    shuffled = random.Random(seed).sample(passages, len(passages))

    kept: list[dict] = []
    trimmed_count = 0
    skipped_count = 0
    for passage in shuffled:
        if len(kept) >= n:
            break
        original = (passage["ocr_input"], passage["gold"])
        pair = trim_pair(original[0], original[1], n_tokens, MAX_PASSAGE_TOKENS)
        # drop degenerate/too-short cuts
        if pair is None or len(pair[1]) < 200:
            skipped_count += 1
            continue
        if pair != original:
            trimmed_count += 1
        ocr_text, gold_text = pair
        kept.append({"id": passage["id"], "ocr_input": ocr_text, "gold": gold_text})

    print(
        f"sampled {len(kept)} passages ({trimmed_count} trimmed to <= {MAX_PASSAGE_TOKENS} "
        f"tokens, {skipped_count} skipped as untrimmable/too short)"
    )
    return kept


# ---------------------------------------------------------------- generation


def clean_output(text: str) -> str:
    """Strip surrounding whitespace and a leading "corrected text:" label.

    The label is matched case-insensitively, with an optional colon, and is
    removed together with the whitespace that follows it. A notice is printed
    whenever a label was removed.
    """
    stripped = text.strip()
    label = re.match(r"^\s*corrected text:?\s*", stripped, flags=re.IGNORECASE)
    if label is None:
        return stripped
    print("  [clean_output stripped a prefix]")
    return stripped[label.end():]


def count_generated_tokens(generated_ids, tokenizer) -> int:
    """Number of tokens before the first EOS or pad id.

    `generated_ids` must expose `.tolist()`; the stop ids are the tokenizer's
    `eos_token_id` and `pad_token_id`. If neither occurs, every token counts.
    """
    token_ids = generated_ids.tolist()
    stop_ids = {tokenizer.eos_token_id, tokenizer.pad_token_id}
    # Index of the first stop token == how many tokens precede it.
    return next(
        (index for index, token_id in enumerate(token_ids) if token_id in stop_ids),
        len(token_ids),
    )


def run_model(model_key: str, passages: list[dict], smoke: bool) -> dict[str, dict]:
    """Load one model, run all passages, free the model. Returns id -> output dict.

    model_key selects the checkpoint: any key starting with "diffusiongemma"
    loads DG_MODEL, "gemma4_moe" loads G4_MOE_MODEL, anything else G4_MODEL.

    model_key "diffusiongemma_canvas" = same model, but the denoising canvas is
    initialised with the OCR text (via the undocumented `decoder_input_ids`
    hook in DiffusionGemmaForBlockDiffusion.generate) instead of random tokens
    — testing whether correction-as-denoising stays closer to the input.

    Each output dict has the keys text, seconds, tokens_generated,
    denoising_steps (None for non-diffusion models), tokens_per_forward and
    thought_chars. With smoke=True the OCR/gold/raw/cleaned text of every
    passage is also printed (truncated).

    `passages` must be non-empty: the first passage is used for an untimed
    warmup generation before the measured loop.
    """
    import torch
    from transformers import AutoProcessor

    canvas_init = model_key == "diffusiongemma_canvas"
    is_dg = model_key.startswith("diffusiongemma")
    if is_dg:
        from transformers import DiffusionGemmaForBlockDiffusion, TextDiffusionStreamer

        class StepCountingStreamer(TextDiffusionStreamer):
            """Counts denoising steps; suppresses the default console printing
            (the parent prints every draft with ANSI rewrites — unusable in job logs)."""

            def __init__(self, tokenizer):
                super().__init__(tokenizer=tokenizer)
                self.n_steps = 0

            def put_draft(self, value, **kwargs):
                # One call is counted as one denoising step; the draft itself is ignored.
                self.n_steps += 1

            def put(self, value):
                pass

            def end(self):
                pass

        model_id = DG_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = DiffusionGemmaForBlockDiffusion.from_pretrained(
            model_id, dtype="auto", device_map="auto"
        )
    else:
        from transformers import AutoModelForMultimodalLM

        model_id = G4_MOE_MODEL if model_key == "gemma4_moe" else G4_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForMultimodalLM.from_pretrained(model_id, dtype="auto", device_map="auto")

    tokenizer = processor.tokenizer
    # One generator shared by every generate() call (the warmup included), so
    # the sequence of random tails is the same on every run.
    canvas_rng = torch.Generator().manual_seed(0)  # deterministic canvas tail padding

    def generate(ocr_text: str) -> dict:
        """Run one timed generation for `ocr_text`; the result still carries `_raw`."""
        message = [{"role": "user", "content": PROMPT_TEMPLATE.format(ocr=ocr_text)}]
        inputs = processor.apply_chat_template(
            message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]

        gen_kwargs: dict = {"max_new_tokens": 256}
        streamer = None
        if is_dg:
            # generation_config defaults for the entropy sampler (no greedy equivalent)
            streamer = StepCountingStreamer(tokenizer)
            gen_kwargs["streamer"] = streamer
            if canvas_init:
                # Seed the first denoising canvas with the OCR text instead of
                # random tokens; pad the tail with random tokens as the sampler
                # would. Canvas must be exactly canvas_length wide.
                canvas_length = getattr(model.generation_config, "canvas_length", None) or 256
                ids = tokenizer(ocr_text, add_special_tokens=False)["input_ids"][:canvas_length]
                vocab = model.config.text_config.vocab_size
                pad = torch.randint(vocab, (canvas_length - len(ids),), generator=canvas_rng)
                canvas = torch.cat([torch.tensor(ids, dtype=torch.long), pad])
                gen_kwargs["decoder_input_ids"] = canvas.unsqueeze(0).to(model.device)
        else:
            gen_kwargs["do_sample"] = False  # greedy

        # Synchronise on both sides of the call so the wall-clock time covers
        # the GPU work of this generation only.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        output = model.generate(**inputs, **gen_kwargs)
        torch.cuda.synchronize()
        seconds = time.perf_counter() - t0

        # DiffusionGemma returns a DiffusionGemmaGenerationOutput (sequences
        # includes the prompt, like AR generate); plain tensor for Gemma-4.
        seq = output.sequences if hasattr(output, "sequences") else output
        # Drop the prompt tokens when the sequence is longer than the prompt;
        # otherwise keep the sequence as returned.
        generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
        tpf = getattr(output, "tokens_per_forward", None)
        if torch.is_tensor(tpf):
            tpf = int(tpf.flatten()[0])
        # Special tokens are kept so extract_answer can find the channel/stop markers.
        raw = tokenizer.decode(generated, skip_special_tokens=False)
        answer, thought = extract_answer(raw)
        if thought:
            print(f"  [WARNING: thought content present ({len(thought)} chars)]")
        return {
            "text": clean_output(answer),
            "_raw": raw,
            "seconds": round(seconds, 3),
            "tokens_generated": count_generated_tokens(generated, tokenizer),
            "denoising_steps": streamer.n_steps if streamer else None,
            "tokens_per_forward": tpf,
            "thought_chars": len(thought),
        }

    # The warmup result is discarded, so its timing never reaches the results.
    print("warmup generation (uncounted) ...")
    generate(passages[0]["ocr_input"])

    results: dict[str, dict] = {}
    for i, p in enumerate(passages):
        out = generate(p["ocr_input"])
        # The raw decode is only used for smoke logging; it is not stored.
        raw = out.pop("_raw")
        results[p["id"]] = out
        print(
            f"[{model_key} {i + 1}/{len(passages)}] {out['seconds']}s, "
            f"{out['tokens_generated']} tok"
            + (f", {out['denoising_steps']} steps" if out["denoising_steps"] else "")
        )
        if smoke:
            print(f"  OCR:  {p['ocr_input'][:200]}")
            print(f"  GOLD: {p['gold'][:200]}")
            print(f"  RAW:  {raw[:300]}")
            print(f"  OUT:  {out['text'][:200]}")

    model = None  # noqa: F841 — drop the closure-captured ref so the GPU frees
    gc.collect()
    torch.cuda.empty_cache()
    return results


# ---------------------------------------------------------------- main


def _ensure_private_dataset_repo(repo_id: str):
    """Create `repo_id` as a private dataset repo if missing and verify it is private.

    Returns the HfApi client so the caller can reuse it for the uploads.
    Raises RuntimeError when the repo already exists and is not private.
    """
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id, repo_type="dataset", private=True, exist_ok=True)
    # exist_ok=True accepts a pre-existing repo whatever its visibility, so
    # `private=True` alone does not guarantee the outputs stay private.
    if not api.repo_info(repo_id, repo_type="dataset").private:
        raise RuntimeError(
            f"--out-repo {repo_id} exists but is not private; refusing to upload raw "
            "outputs (BLN600 text is CC-BY-NC and must not be published)"
        )
    return api


def main() -> None:
    """Parse CLI arguments, run the selected models, and write/upload the outputs.

    Writes `raw_outputs.jsonl` (one record per evaluated passage; run metadata
    is attached to the first record) and, with --cache-examples,
    `examples_cached.json` for the ICDAR examples. With --out-repo both files
    are uploaded to that dataset repo, which is created/verified as private
    BEFORE any model is loaded so a bad token or a public repo fails fast
    instead of after the GPU work.

    Raises:
        RuntimeError: the out repo is not private, or no passage survived sampling.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--mode", choices=["smoke", "full"], default="smoke")
    parser.add_argument("--n", type=int, default=75)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--models", choices=["both", "dg", "g4", "g4moe", "all"], default="both")
    parser.add_argument(
        "--canvas-init",
        action="store_true",
        help="also run DiffusionGemma with the canvas initialised from the OCR text",
    )
    parser.add_argument("--dataset", choices=["bln600", "icdar"], default="bln600")
    parser.add_argument("--cache-examples", action="store_true")
    parser.add_argument("--out-repo", default=None, help="private dataset repo for raw outputs")
    parser.add_argument("--workdir", type=Path, default=Path("data"))
    args = parser.parse_args()

    import torch
    import transformers

    print(f"transformers {transformers.__version__}, torch {torch.__version__}, "
          f"cuda: {torch.cuda.get_device_name(0)}")

    # Resolve the upload target up front (see docstring).
    api = _ensure_private_dataset_repo(args.out_repo) if args.out_repo else None

    loader = download_bln600 if args.dataset == "bln600" else download_icdar_english
    source = loader(args.workdir)
    # smoke mode always runs 3 passages, whatever --n says
    passages = sample_passages(source, args.n if args.mode == "full" else 3, args.seed)
    if not passages:
        # run_model warms up on passages[0]; fail here rather than after a model load.
        raise RuntimeError(f"no usable passages sampled from {args.dataset}")
    print(f"running {len(passages)} passages, models={args.models}")

    examples = []
    if args.cache_examples:
        # Reuse the already-parsed list when ICDAR is also the eval set.
        icdar = source if args.dataset == "icdar" else download_icdar_english(args.workdir)
        examples = sample_passages(icdar, N_EXAMPLES, args.seed)
        for e in examples:
            # Prefix keeps example ids distinct from eval ids in the shared output maps.
            e["id"] = "example/" + e["id"]

    model_keys = {
        "both": ["diffusiongemma", "gemma4"],
        "dg": ["diffusiongemma"],
        "g4": ["gemma4"],
        "g4moe": ["gemma4_moe"],
        "all": ["diffusiongemma", "gemma4", "gemma4_moe"],
    }[args.models]
    if args.canvas_init:
        model_keys.insert(1, "diffusiongemma_canvas")
    all_passages = passages + examples
    outputs: dict[str, dict[str, dict]] = {}
    for key in model_keys:
        try:
            outputs[key] = run_model(key, all_passages, smoke=args.mode == "smoke")
        except Exception as e:  # noqa: BLE001 — a failed condition shouldn't sink the others
            # Only the experimental canvas condition is optional; any other failure aborts.
            if key == "diffusiongemma_canvas":
                print(f"[{key} FAILED, continuing without it: {type(e).__name__}: {e}]")
            else:
                raise

    meta = {
        "date": time.strftime("%Y-%m-%d"),
        "dataset": args.dataset,
        "n": len(passages),
        "seed": args.seed,
        "max_passage_tokens": MAX_PASSAGE_TOKENS,
        "prompt": PROMPT_TEMPLATE,
        "transformers": transformers.__version__,
        "torch": torch.__version__,
        "gpu": torch.cuda.get_device_name(0),
        "generation": {
            "diffusiongemma": "generation_config defaults (entropy sampler), max_new_tokens=256",
            "diffusiongemma_canvas": "as diffusiongemma, but first canvas seeded with the OCR"
            " text via decoder_input_ids (random tail padding, seed 0)",
            "gemma4": "do_sample=False (greedy), max_new_tokens=256",
            "gemma4_moe": "do_sample=False (greedy), max_new_tokens=256",
        },
    }

    out_path = Path("raw_outputs.jsonl")
    with out_path.open("w", encoding="utf-8") as f:
        for i, p in enumerate(passages):
            record = {
                "id": p["id"],
                "ocr_input": p["ocr_input"],
                "gold": p["gold"],
                # `if k in outputs` skips a condition that failed and was tolerated above
                "output": {k: outputs[k][p["id"]] for k in model_keys if k in outputs},
            }
            if i == 0:
                record["meta"] = meta
            f.write(json.dumps(record) + "\n")
    print(f"wrote {out_path} ({len(passages)} records)")

    cache_path = None
    if examples:
        cache_path = Path("examples_cached.json")
        cache_path.write_text(
            json.dumps(
                [
                    {
                        "id": e["id"],
                        "ocr_input": e["ocr_input"],
                        "gold": e["gold"],
                        "output": {k: outputs[k][e["id"]] for k in model_keys if k in outputs},
                    }
                    for e in examples
                ],
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"wrote {cache_path}")

    if api is not None:
        api.upload_file(
            path_or_fileobj=out_path, path_in_repo=out_path.name,
            repo_id=args.out_repo, repo_type="dataset",
        )
        if cache_path:
            api.upload_file(
                path_or_fileobj=cache_path, path_in_repo=cache_path.name,
                repo_id=args.out_repo, repo_type="dataset",
            )
        print(f"uploaded to https://huggingface.co/datasets/{args.out_repo} (private)")


# Script entry point: run the benchmark only when executed directly
# (e.g. via `hf jobs uv run ... benchmark.py`), not when imported.
if __name__ == "__main__":
    main()
