scrubdata / eval /generalization.py
OpenAI Codex
deploy: add sponsor:openai tag (Best Use of Codex) + Codex-hardened build
16dc556
Raw
History Blame Contribute Delete
8.52 kB
"""D1 — the GENERALIZATION metric: held-out-source real-error evaluation.
The wide-suite REAL slice mixes sources whose pairs are IN the champion's training mix
(hospital/beers/movies_1 -> mixA), so it part-measures memorization. This metric fixes
that and one more honesty problem:
* HELD-OUT SOURCES ONLY: a model is scored only on real-error benchmarks whose pairs
were never used to train it. The split is explicit and committed (TRAIN_SOURCES);
new harvested sources must be assigned to exactly one side.
* ERROR-CLASS BREAKDOWN: benchmark errors split by the SAME variant gate the training
derivation uses (training.real_data._is_variant — one source of truth). A
canonicalization system claims competence on the VARIANT class (typos / casing /
aliases); imputation-class errors (missing or non-variant rewrites) are reported,
never hidden, but a system that abstains on them is behaving correctly.
Headline numbers per system:
GEN-F1 churn-neutral F1 over ALL errors, macro over held-out sources
VARIANT-RECALL share of variant-class errors repaired (claimed competence)
VARIANT-PREC of committed changes on variant cells, share correct
damage clean cells corrupted (churn-neutral)
DISCLOSED class imperfection: the string-variant gate over-counts on flights —
single-digit time differences ('7:59 p.m.' vs '7:58 p.m.') pass the similarity
threshold but are cross-source VALUE disagreements (need per-entity cross-row
voting, a different capability), not surface canonicalization. ~950 of flights'
1049 "variant" errors are of this kind; treat flights' variant-recall as a
lower-bound stress number, not addressable headroom.
uv run python -m eval.generalization # grounded heuristic baseline
"""
from __future__ import annotations
import argparse
import json
from scrubdata.executor import apply_plan
from scrubdata.planner import mock_plan
from training.real_data import _is_variant
from .metrics import _cell_equal
from .run_real_multi import _cell_only, _fetch, _sem_equal, score
# pairs used to train the current champion (v6 = mixA) — anything here is OFF-LIMITS
# for generalization scoring of that model. Update per training run.
TRAIN_SOURCES = {"v6": {"hospital", "beers", "movies_1"}}
# held-out real-error sources. Harvested D1 sources get appended here OR to the
# training side — never both. ed2_restaurants (stage-2 harvest): real NYC-restaurant
# typos, in-regime, EVAL-ONLY — its sibling domain source (fodors_zagats) trains, so
# this measures cross-source same-domain transfer. dblp_scholar was REJECTED as an
# eval source: its gold systematically prefers the opposite case convention from the
# dirty side (Scholar lowercase vs DBLP Title Case), which measures convention
# preference, not cleaning — the artifact this metric is designed against.
EVAL_SOURCES = ["flights", "rayyan", "ed2_restaurants"]
def variant_breakdown(dirty, clean, out) -> dict:
"""Split benchmark errors by class and count repairs per class (churn-neutral)."""
n = min(len(dirty), len(out), len(clean))
c = {"variant_errors": 0, "variant_fixed": 0, "variant_changed": 0,
"variant_good": 0, "other_errors": 0, "other_fixed": 0}
for j, col in enumerate(dirty.columns):
present = col in out.columns
for i in range(n):
dv, cv = dirty.iat[i, j], clean.iat[i, j]
if _cell_equal(dv, cv):
continue # not a benchmark error
ov = out.iloc[i][col] if present else dv
chg = present and not _cell_equal(ov, dv)
if chg and _sem_equal(ov, dv) and not _cell_equal(ov, cv):
chg = False # churn: ignore
fixed = _cell_equal(ov, cv) or (_sem_equal(ov, cv) and chg)
is_variant = (str(dv).strip() and str(cv).strip()
and _is_variant(str(dv), str(cv)))
if is_variant:
c["variant_errors"] += 1
c["variant_fixed"] += int(fixed)
if chg:
c["variant_changed"] += 1
c["variant_good"] += int(_sem_equal(ov, cv))
else:
c["other_errors"] += 1
c["other_fixed"] += int(fixed)
return c
def evaluate_generalization(planner, sources=None, label: str = "system") -> dict:
sources = sources or EVAL_SOURCES
rows = []
for name in sources:
# FULL tables, no truncation — ed2_restaurants' real errors are concentrated
# outside the first 2k rows (_raha_pair's head(2000) hid 473 of 477).
dirty, clean = _fetch(name)
cleaned, _ = apply_plan(dirty, _cell_only(planner(dirty)))
m = score(dirty, clean, cleaned)
b = variant_breakdown(dirty, clean, cleaned)
rows.append({"source": name, **{k: m[k] for k in
("f1", "precision", "recall", "damage")}, **b})
print(f" {name:<10} F1={m['f1']:.3f} dmg={m['damage']:.3f} | variant: "
f"{b['variant_fixed']}/{b['variant_errors']} fixed, "
f"{b['variant_good']}/{b['variant_changed']} changes good | "
f"other: {b['other_fixed']}/{b['other_errors']}", flush=True)
return _aggregate(rows, sources, label)
def evaluate_captured_union(plans: dict, sources, label: str, tau: float = 0.5) -> dict:
"""Score the SHIPPED pipeline from captured raw model plans (Modal --capture):
per source, verify(tau) the captured plan, union with the grounded heuristic —
byte-identical composition to scrubdata/active.py."""
from scrubdata.verifier import union_plans, verify_plan
def planner_for(name):
def planner(df, *_):
return union_plans(verify_plan(df, plans[name], tau=tau), mock_plan(df))
return planner
rows = []
for name in sources:
dirty, clean = _fetch(name)
cleaned, _ = apply_plan(dirty, _cell_only(planner_for(name)(dirty)))
m = score(dirty, clean, cleaned)
b = variant_breakdown(dirty, clean, cleaned)
rows.append({"source": name, **{k: m[k] for k in
("f1", "precision", "recall", "damage")}, **b})
print(f" {name:<16} F1={m['f1']:.3f} dmg={m['damage']:.3f} | variant: "
f"{b['variant_fixed']}/{b['variant_errors']} fixed", flush=True)
return _aggregate(rows, sources, label)
def _aggregate(rows, sources, label) -> dict:
def mean(xs):
xs = list(xs)
return sum(xs) / len(xs) if xs else 0.0
def rate(num, den):
return num / den if den else 0.0
out = {
"system": label, "sources": list(sources),
"gen_f1": mean(r["f1"] for r in rows),
"variant_recall": mean(rate(r["variant_fixed"], r["variant_errors"]) for r in rows),
"variant_precision": mean(rate(r["variant_good"], r["variant_changed"])
if r["variant_changed"] else 1.0 for r in rows),
"other_recall": mean(rate(r["other_fixed"], r["other_errors"]) for r in rows),
"damage": mean(r["damage"] for r in rows),
"per_source": rows,
}
print(f"{label}: GEN-F1={out['gen_f1']:.3f} VARIANT-RECALL={out['variant_recall']:.3f} "
f"VARIANT-PREC={out['variant_precision']:.3f} dmg={out['damage']:.3f}")
return out
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--sources", default=",".join(EVAL_SOURCES))
ap.add_argument("--plans", default=None,
help="JSON file {source: captured raw model plan} -> score the "
"shipped union pipeline instead of the local baselines")
ap.add_argument("--label", default="captured union")
ap.add_argument("--out", default="eval/results/generalization_baseline.json")
args = ap.parse_args()
sources = args.sources.split(",")
if args.plans:
plans = json.load(open(args.plans))
results = [evaluate_captured_union(plans, sources, args.label)]
else:
results = [
evaluate_generalization(mock_plan, sources, "grounded heuristic"),
evaluate_generalization(
lambda df: {"table_operations": [], "columns": [], "flags": []},
sources, "no-op"),
]
json.dump(results, open(args.out, "w"), indent=1)
print(f"written to {args.out}")
if __name__ == "__main__":
main()