"""Train the badger-55 meter reader heads from the published Hugging Face dataset. Downloads: - https://huggingface.co/datasets/S3CUR/badger-55-watermeter - `facebook/dinov2-small` (~85 MB, public) Then trains three heads on the pre-rectified slot crops in the dataset: - `digit_classifier.pt` — general-purpose 10-class digit head. Pooled across slots 4+5+6+7 (each saw all 10 digit classes during data collection), giving the head varied lighting/bezel context per class. At inference it's applied to slots 0–4; slots 0–3 will emit whatever constant their drum happens to be showing, since the source meter's upper drums didn't move during data collection. - `d4d5_predictor90.pt` — 90-bin angular head pooled over `slot in {4,5}` (KL on wrapped-Gaussian soft targets) - `d6d7_predictor90.pt` — same architecture, pooled over `slot in {6,7}`, including the platinum d7 atlas Weights land in `./weights/`. `demo.py` consumes them from there. Usage: python train.py # train all three python train.py --skip-classifier # angular heads only python train.py --epochs 120 """ # A 4th `SinCosSpecialist` head used to train here as a third voter for # the demo. Removed 2026-05-24 — its val MAE was 2-3× worse than # Predictor90 and the consensus never picked it over the primary. from __future__ import annotations import argparse import os import time from pathlib import Path # --- HF download tuning (must be set BEFORE importing huggingface_hub) --- # Xet high-performance multi-stream downloader. Replaces the deprecated # `HF_HUB_ENABLE_HF_TRANSFER` flag in huggingface_hub >= 1.16 (which is # silently ignored — don't use it). os.environ.setdefault('HF_XET_HIGH_PERFORMANCE', '1') # Per-blob HTTP timeout, in seconds. The default is effectively unbounded, # so a blob fetch that gets routed to a slow CloudFront edge can wedge # the entire pull forever. 30s is plenty for a 10-KB JPEG; if a stream # is silent that long it's stuck — kill it and let the retry loop fan # out to a different edge. os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', '30') import numpy as np import pandas as pd import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download import models # local module HERE = Path(__file__).parent WEIGHTS = HERE / 'weights' DATASET_ID = 'S3CUR/badger-55-watermeter' # ── dataset download ────────────────────────────────────────────────── def download_slots_parquet(cache_dir: Path | None = None) -> Path: """Fetch the single `slots.parquet` file (JPEG bytes embedded inline). The v2 dataset layout is two root-level parquets — no loose images — so a cold pull is one HTTP request, one ~35 MB stream, one second on a fast link. No retries needed; if the single GET fails huggingface_hub already retries internally.""" t0 = time.time() print(f"[hf] fetching {DATASET_ID}:slots.parquet") local = hf_hub_download( repo_id=DATASET_ID, repo_type='dataset', filename='slots.parquet', cache_dir=str(cache_dir) if cache_dir else None, ) sz = Path(local).stat().st_size / 1024 / 1024 print(f"[hf] cached at {local} ({sz:.1f} MB, {time.time()-t0:.1f}s)") return Path(local) # ── feature extraction ──────────────────────────────────────────────── def _default_device() -> str: return 'cuda' if torch.cuda.is_available() else 'cpu' def load_slot_features(slots_parquet: Path, slot_filter: list[int], device: str | None = None ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]: """Read slots.parquet (JPEG bytes embedded inline), filter to the requested slots, decode each crop, extract DINOv2 features. Returns (feats, thetas, digits, splits, df).""" if device is None: device = _default_device() df = pd.read_parquet(slots_parquet) df = df[df['slot'].isin(slot_filter)].reset_index(drop=True) print(f" filtered to slots={slot_filter}: {len(df)} rows " f"({dict(df['tier'].value_counts())})") import cv2 dino = models.DinoV2(device=device) n = len(df) feats = np.zeros((n, models.DINOV2_DIM), dtype=np.float32) BATCH = 64 t0 = time.time() for i in range(0, n, BATCH): batch_bytes = df['image_bytes'].iloc[i:i+BATCH].tolist() crops = [cv2.imdecode(np.frombuffer(b, np.uint8), cv2.IMREAD_COLOR) for b in batch_bytes] crops = [c for c in crops if c is not None] if len(crops) != len(batch_bytes): raise RuntimeError(f"undecodable crop(s) in batch starting at {i}") arr = models.slot_crops_to_array(crops) feats[i:i+len(crops)] = dino.features(arr).cpu().numpy() if (i // BATCH) % 5 == 0: print(f" features {i+len(crops):5d}/{n} " f"({(i+len(crops))/(time.time()-t0+1e-9):.0f}/s)") return (feats, df['theta_deg'].astype(np.float32).to_numpy(), df['digit'].astype(np.int64).to_numpy(), df['split'].to_numpy(), df) def split_indices(split: np.ndarray): return (split == 'train'), (split == 'val'), (split == 'test') # ── train predictor90 ───────────────────────────────────────────────── def train_predictor90(feats: np.ndarray, thetas: np.ndarray, split: np.ndarray, out_path: Path, epochs=80, lr=3e-3, batch_size=128, sigma_bins=2.0, device: str | None = None, seed=0): if device is None: device = _default_device() torch.manual_seed(seed); np.random.seed(seed) targets = models.wrapped_gaussian_targets(thetas, sigma_bins=sigma_bins) tr, vl, ts = split_indices(split) Xtr = torch.from_numpy(feats[tr]).float().to(device) Ytr = torch.from_numpy(targets[tr]).float().to(device) Xvl = torch.from_numpy(feats[vl]).float().to(device) Tvl = torch.from_numpy(thetas[vl]).float().to(device) Xts = torch.from_numpy(feats[ts]).float().to(device) Tts = torch.from_numpy(thetas[ts]).float().to(device) print(f" train {Xtr.shape[0]} | val {Xvl.shape[0]} | test {Xts.shape[0]}") model = models.Predictor90().to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best = {'val_mae': float('inf'), 'epoch': -1, 'state': None} for ep in range(epochs): model.train() perm = torch.randperm(Xtr.shape[0], device=device) for i in range(0, Xtr.shape[0], batch_size): idx = perm[i:i+batch_size] logits = model(Xtr[idx]) logp = F.log_softmax(logits, dim=-1) loss = F.kl_div(logp, Ytr[idx], reduction='batchmean') opt.zero_grad(); loss.backward(); opt.step() sched.step() model.eval() with torch.no_grad(): vp = models.predictor90_decode(model(Xvl))['theta_deg'] vl_mae = _circ_mae(vp.cpu().numpy(), Tvl.cpu().numpy()) if vl_mae < best['val_mae']: best = {'val_mae': float(vl_mae), 'epoch': ep, 'state': {k: v.clone() for k, v in model.state_dict().items()}} if ep % 5 == 0 or ep == epochs - 1: print(f" ep {ep:3d} | loss {float(loss):.4f} | " f"val MAE {vl_mae:.2f}° (best {best['val_mae']:.2f}° @ ep {best['epoch']})") model.load_state_dict(best['state']) model.eval() with torch.no_grad(): tp = models.predictor90_decode(model(Xts))['theta_deg'] ts_mae = _circ_mae(tp.cpu().numpy(), Tts.cpu().numpy()) out_path.parent.mkdir(parents=True, exist_ok=True) torch.save(best['state'], out_path) print(f" best val MAE {best['val_mae']:.3f}° | test MAE {ts_mae:.3f}°") print(f" saved → {out_path}") def _circ_mae(a, b): d = np.abs(a - b) % 360.0 return float(np.minimum(d, 360.0 - d).mean()) # ── train d4 classifier ─────────────────────────────────────────────── def train_d4_classifier(feats: np.ndarray, digits: np.ndarray, split: np.ndarray, out_path: Path, epochs=60, lr=1e-3, batch_size=128, device: str | None = None, seed=0): if device is None: device = _default_device() torch.manual_seed(seed); np.random.seed(seed) tr, vl, ts = split_indices(split) Xtr = torch.from_numpy(feats[tr]).float().to(device) Ytr = torch.from_numpy(digits[tr]).long().to(device) Xvl = torch.from_numpy(feats[vl]).float().to(device) Yvl = torch.from_numpy(digits[vl]).long().to(device) Xts = torch.from_numpy(feats[ts]).float().to(device) Yts = torch.from_numpy(digits[ts]).long().to(device) print(f" train {Xtr.shape[0]} | val {Xvl.shape[0]} | test {Xts.shape[0]}") model = models.SlotClassifier().to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best = {'val_acc': -1.0, 'epoch': -1, 'state': None} for ep in range(epochs): model.train() perm = torch.randperm(Xtr.shape[0], device=device) for i in range(0, Xtr.shape[0], batch_size): idx = perm[i:i+batch_size] logits = model(Xtr[idx]) loss = F.cross_entropy(logits, Ytr[idx]) opt.zero_grad(); loss.backward(); opt.step() sched.step() model.eval() with torch.no_grad(): vacc = (model(Xvl).argmax(dim=-1) == Yvl).float().mean().item() if vacc > best['val_acc']: best = {'val_acc': vacc, 'epoch': ep, 'state': {k: v.clone() for k, v in model.state_dict().items()}} if ep % 5 == 0 or ep == epochs - 1: print(f" ep {ep:3d} | loss {float(loss):.4f} | " f"val acc {vacc:.4f} (best {best['val_acc']:.4f} @ ep {best['epoch']})") model.load_state_dict(best['state']) model.eval() with torch.no_grad(): tacc = (model(Xts).argmax(dim=-1) == Yts).float().mean().item() out_path.parent.mkdir(parents=True, exist_ok=True) torch.save(best['state'], out_path) print(f" best val acc {best['val_acc']:.4f} | test acc {tacc:.4f}") print(f" saved → {out_path}") # ── per-head training recipe (learned in production 2026-05-24) ─────── # # Bao went through three rounds of sweeps on this same dataset: # # Round 1: 80 epochs at lr 3e-3 for every head (one-size-fits-all default). # Predictor90 heads landed sub-1° val MAE. Specialist heads got stuck # at ~4-5° val MAE — loss curve was still dropping at the last epoch, # i.e. the head hadn't converged. # Round 2: tried 200 epochs at lr 3e-3, then 200 at lr 1e-3, then seed=7 # at the original recipe to disambiguate val-split luck from real # training noise. The 200/1e-3 combo won decisively (d6d7 specialist # dropped from 3.77° → 2.41°, a 36% reduction). # Round 3: after another round of human retags for Geneva-mechanism # margin, the same recipe held: 200/1e-3 specialists, 80/3e-3 # predictor90s. # # So the per-head defaults below encode that lesson. --epochs on the # command line still overrides if you want to experiment. RECIPE = { 'predictor90': {'epochs': 80, 'lr': 3e-3}, # softmax over 90 bins; KL loss 'classifier': {'epochs': 60, 'lr': 1e-3}, # 10-way softmax; already plenty } # ── main ────────────────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser() ap.add_argument('--cache-dir', default=None, help='HF cache root (default: ~/.cache/huggingface)') ap.add_argument('--local-parquet', default=None, help='Skip HF download and read directly from a local ' 'slots.parquet (bytes embedded). NORMAL usage ' 'downloads from HF.') ap.add_argument('--epochs', type=int, default=None, help='Override the per-head epoch defaults from RECIPE. ' 'Use only when experimenting; the defaults are what ' 'the production sweep landed on.') ap.add_argument('--skip-classifier', action='store_true') ap.add_argument('--skip-d4d5', action='store_true') ap.add_argument('--skip-d6d7', action='store_true') ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') args = ap.parse_args() def E(key: str) -> int: return args.epochs if args.epochs is not None else RECIPE[key]['epochs'] def L(key: str) -> float: return RECIPE[key]['lr'] print(f"[start] device={args.device}") if args.local_parquet: ds = Path(args.local_parquet) print(f"[local] using {ds} (skipping HF download)") else: ds = download_slots_parquet( Path(args.cache_dir) if args.cache_dir else None) if not args.skip_d4d5: print(f"\n== train d4d5_predictor90 ({E('predictor90')} ep @ lr {L('predictor90')}) ==") feats, thetas, _, split, _ = load_slot_features(ds, [4, 5], args.device) train_predictor90(feats, thetas, split, WEIGHTS / 'd4d5_predictor90.pt', epochs=E('predictor90'), lr=L('predictor90'), device=args.device) if not args.skip_d6d7: print(f"\n== train d6d7_predictor90 ({E('predictor90')} ep @ lr {L('predictor90')}) ==") feats, thetas, _, split, _ = load_slot_features(ds, [6, 7], args.device) train_predictor90(feats, thetas, split, WEIGHTS / 'd6d7_predictor90.pt', epochs=E('predictor90'), lr=L('predictor90'), device=args.device) if not args.skip_classifier: print(f"\n== train digit_classifier (10-class, pooled d4+d5+d6+d7) " f"({E('classifier')} ep @ lr {L('classifier')}) ==") feats, _, digits, split, _ = load_slot_features(ds, [4, 5, 6, 7], args.device) train_d4_classifier(feats, digits, split, WEIGHTS / 'digit_classifier.pt', epochs=E('classifier'), lr=L('classifier'), device=args.device) print(f"\n[done] weights in {WEIGHTS}") if __name__ == '__main__': main()