| """Train the badger-55 meter reader heads from the published Hugging Face |
| dataset. |
| |
| Downloads: |
| - https://huggingface.co/datasets/S3CUR/badger-55-watermeter |
| - `facebook/dinov2-small` (~85 MB, public) |
| |
| Then trains three heads on the pre-rectified slot crops in the dataset: |
| |
| - `digit_classifier.pt` — general-purpose 10-class digit head. Pooled |
| across slots 4+5+6+7 (each saw all 10 digit |
| classes during data collection), giving the |
| head varied lighting/bezel context per class. |
| At inference it's applied to slots 0–4; |
| slots 0–3 will emit whatever constant their |
| drum happens to be showing, since the source |
| meter's upper drums didn't move during data |
| collection. |
| - `d4d5_predictor90.pt` — 90-bin angular head pooled over `slot in {4,5}` |
| (KL on wrapped-Gaussian soft targets) |
| - `d6d7_predictor90.pt` — same architecture, pooled over `slot in {6,7}`, |
| including the platinum d7 atlas |
| |
| Weights land in `./weights/`. `demo.py` consumes them from there. |
| |
| Usage: |
| python train.py # train all three |
| python train.py --skip-classifier # angular heads only |
| python train.py --epochs 120 |
| """ |
| |
| |
| |
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import time |
| from pathlib import Path |
|
|
| |
| |
| |
| |
# Hugging Face Hub tuning knobs, applied before `huggingface_hub` is imported
# further down in this file. `setdefault` only fills in a value when the
# variable is not already present, so anything the user exported wins.
#   HF_XET_HIGH_PERFORMANCE  - presumably enables hf-xet's high-throughput
#                              transfer mode; confirm against the hub docs.
#   HF_HUB_DOWNLOAD_TIMEOUT  - presumably the per-download timeout in seconds;
#                              confirm against the hub docs.
for _hf_var, _hf_default in (
    ('HF_XET_HIGH_PERFORMANCE', '1'),
    ('HF_HUB_DOWNLOAD_TIMEOUT', '30'),
):
    os.environ.setdefault(_hf_var, _hf_default)
del _hf_var, _hf_default
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
|
|
| import models |
|
|
|
|
# Directory that contains this script; other paths are resolved relative to it.
HERE = Path(__file__).parent
# Folder the trained head checkpoints are written to.
WEIGHTS = HERE.joinpath('weights')
# Hugging Face dataset repository the slot crops are fetched from.
DATASET_ID = 'S3CUR/badger-55-watermeter'
|
|
|
|
| |
def download_slots_parquet(cache_dir: Path | None = None) -> Path:
    """Download `slots.parquet` from the dataset repo and return its local path.

    The file is fetched with `hf_hub_download` from the `DATASET_ID` dataset
    repository. Progress, the cached location, the file size in MB and the
    elapsed wall-clock time are printed.

    Original author's note: the v2 dataset layout is two root-level parquets
    — no loose images — so a cold pull is one HTTP request, one ~35 MB
    stream, about one second on a fast link. No retry loop is implemented
    here; huggingface_hub is said to retry internally if the single GET fails.

    Args:
        cache_dir: Optional Hugging Face cache root. A falsy value (including
            `None`) is forwarded as `None`, i.e. the hub's own default.

    Returns:
        Path to the downloaded (or already cached) parquet file.
    """
    started = time.time()
    print(f"[hf] fetching {DATASET_ID}:slots.parquet")

    hub_cache = str(cache_dir) if cache_dir else None
    local_str = hf_hub_download(
        repo_id=DATASET_ID,
        repo_type='dataset',
        filename='slots.parquet',
        cache_dir=hub_cache,
    )
    local_path = Path(local_str)

    size_mb = local_path.stat().st_size / 1024 / 1024
    elapsed = time.time() - started
    print(f"[hf] cached at {local_str} ({size_mb:.1f} MB, {elapsed:.1f}s)")
    return local_path
|
|
|
|
| |
| def _default_device() -> str: |
| return 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
def load_slot_features(slots_parquet: Path, slot_filter: list[int],
                       device: str | None = None
                       ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]:
    """Load slot crops from the parquet file and embed them with DINOv2.

    Rows whose `slot` value is in `slot_filter` are kept. Each row's
    `image_bytes` is decoded with OpenCV, and the crops are passed in batches
    of 64 through `models.DinoV2` to fill an `(n_rows, models.DINOV2_DIM)`
    float32 feature matrix.

    Args:
        slots_parquet: Path to `slots.parquet` (JPEG bytes embedded inline).
        slot_filter: Slot indices to keep.
        device: Device string handed to `models.DinoV2`; `None` selects
            `_default_device()`.

    Returns:
        `(feats, thetas, digits, splits, df)`: the feature matrix, the
        `theta_deg` column as float32, the `digit` column as int64, the
        `split` column as-is, and the filtered DataFrame.

    Raises:
        RuntimeError: if any crop in a batch cannot be decoded.
    """
    device = _default_device() if device is None else device

    table = pd.read_parquet(slots_parquet)
    keep = table['slot'].isin(slot_filter)
    df = table[keep].reset_index(drop=True)
    tier_counts = dict(df['tier'].value_counts())
    print(f" filtered to slots={slot_filter}: {len(df)} rows "
          f"({tier_counts})")

    # cv2 is imported here, at call time, rather than at module top (as in
    # the original code).
    import cv2

    extractor = models.DinoV2(device=device)
    total = len(df)
    feats = np.zeros((total, models.DINOV2_DIM), dtype=np.float32)
    chunk = 64
    started = time.time()

    for chunk_no, start in enumerate(range(0, total, chunk)):
        blobs = df['image_bytes'].iloc[start:start + chunk].tolist()
        crops = [
            cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_COLOR)
            for blob in blobs
        ]
        # imdecode yields None for bytes it cannot decode; abort instead of
        # letting feature rows drift out of step with the DataFrame rows.
        if any(img is None for img in crops):
            raise RuntimeError(f"undecodable crop(s) in batch starting at {start}")

        stop = start + len(crops)
        batch_array = models.slot_crops_to_array(crops)
        feats[start:stop] = extractor.features(batch_array).cpu().numpy()

        # Progress line on every fifth batch.
        if chunk_no % 5 == 0:
            rate = stop / (time.time() - started + 1e-9)
            print(f" features {stop:5d}/{total} "
                  f"({rate:.0f}/s)")

    thetas = df['theta_deg'].astype(np.float32).to_numpy()
    digits = df['digit'].astype(np.int64).to_numpy()
    splits = df['split'].to_numpy()
    return (feats, thetas, digits, splits, df)
|
|
|
|
def split_indices(split: np.ndarray):
    """Return the `(train, val, test)` masks obtained by comparing `split`
    element-wise against the labels `'train'`, `'val'` and `'test'`."""
    train_mask, val_mask, test_mask = (
        split == label for label in ('train', 'val', 'test')
    )
    return train_mask, val_mask, test_mask
|
|
|
|
| |
| def train_predictor90(feats: np.ndarray, thetas: np.ndarray, split: np.ndarray, |
| out_path: Path, epochs=80, lr=3e-3, batch_size=128, |
| sigma_bins=2.0, device: str | None = None, seed=0): |
| if device is None: device = _default_device() |
| torch.manual_seed(seed); np.random.seed(seed) |
| targets = models.wrapped_gaussian_targets(thetas, sigma_bins=sigma_bins) |
| tr, vl, ts = split_indices(split) |
| Xtr = torch.from_numpy(feats[tr]).float().to(device) |
| Ytr = torch.from_numpy(targets[tr]).float().to(device) |
| Xvl = torch.from_numpy(feats[vl]).float().to(device) |
| Tvl = torch.from_numpy(thetas[vl]).float().to(device) |
| Xts = torch.from_numpy(feats[ts]).float().to(device) |
| Tts = torch.from_numpy(thetas[ts]).float().to(device) |
| print(f" train {Xtr.shape[0]} | val {Xvl.shape[0]} | test {Xts.shape[0]}") |
|
|
| model = models.Predictor90().to(device) |
| opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) |
| sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) |
| best = {'val_mae': float('inf'), 'epoch': -1, 'state': None} |
| for ep in range(epochs): |
| model.train() |
| perm = torch.randperm(Xtr.shape[0], device=device) |
| for i in range(0, Xtr.shape[0], batch_size): |
| idx = perm[i:i+batch_size] |
| logits = model(Xtr[idx]) |
| logp = F.log_softmax(logits, dim=-1) |
| loss = F.kl_div(logp, Ytr[idx], reduction='batchmean') |
| opt.zero_grad(); loss.backward(); opt.step() |
| sched.step() |
| model.eval() |
| with torch.no_grad(): |
| vp = models.predictor90_decode(model(Xvl))['theta_deg'] |
| vl_mae = _circ_mae(vp.cpu().numpy(), Tvl.cpu().numpy()) |
| if vl_mae < best['val_mae']: |
| best = {'val_mae': float(vl_mae), 'epoch': ep, |
| 'state': {k: v.clone() for k, v in model.state_dict().items()}} |
| if ep % 5 == 0 or ep == epochs - 1: |
| print(f" ep {ep:3d} | loss {float(loss):.4f} | " |
| f"val MAE {vl_mae:.2f}Β° (best {best['val_mae']:.2f}Β° @ ep {best['epoch']})") |
| model.load_state_dict(best['state']) |
| model.eval() |
| with torch.no_grad(): |
| tp = models.predictor90_decode(model(Xts))['theta_deg'] |
| ts_mae = _circ_mae(tp.cpu().numpy(), Tts.cpu().numpy()) |
| out_path.parent.mkdir(parents=True, exist_ok=True) |
| torch.save(best['state'], out_path) |
| print(f" best val MAE {best['val_mae']:.3f}Β° | test MAE {ts_mae:.3f}Β°") |
| print(f" saved β {out_path}") |
|
|
|
|
| def _circ_mae(a, b): |
| d = np.abs(a - b) % 360.0 |
| return float(np.minimum(d, 360.0 - d).mean()) |
|
|
|
|
| |
| def train_d4_classifier(feats: np.ndarray, digits: np.ndarray, split: np.ndarray, |
| out_path: Path, epochs=60, lr=1e-3, batch_size=128, |
| device: str | None = None, seed=0): |
| if device is None: device = _default_device() |
| torch.manual_seed(seed); np.random.seed(seed) |
| tr, vl, ts = split_indices(split) |
| Xtr = torch.from_numpy(feats[tr]).float().to(device) |
| Ytr = torch.from_numpy(digits[tr]).long().to(device) |
| Xvl = torch.from_numpy(feats[vl]).float().to(device) |
| Yvl = torch.from_numpy(digits[vl]).long().to(device) |
| Xts = torch.from_numpy(feats[ts]).float().to(device) |
| Yts = torch.from_numpy(digits[ts]).long().to(device) |
| print(f" train {Xtr.shape[0]} | val {Xvl.shape[0]} | test {Xts.shape[0]}") |
|
|
| model = models.SlotClassifier().to(device) |
| opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) |
| sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) |
| best = {'val_acc': -1.0, 'epoch': -1, 'state': None} |
| for ep in range(epochs): |
| model.train() |
| perm = torch.randperm(Xtr.shape[0], device=device) |
| for i in range(0, Xtr.shape[0], batch_size): |
| idx = perm[i:i+batch_size] |
| logits = model(Xtr[idx]) |
| loss = F.cross_entropy(logits, Ytr[idx]) |
| opt.zero_grad(); loss.backward(); opt.step() |
| sched.step() |
| model.eval() |
| with torch.no_grad(): |
| vacc = (model(Xvl).argmax(dim=-1) == Yvl).float().mean().item() |
| if vacc > best['val_acc']: |
| best = {'val_acc': vacc, 'epoch': ep, |
| 'state': {k: v.clone() for k, v in model.state_dict().items()}} |
| if ep % 5 == 0 or ep == epochs - 1: |
| print(f" ep {ep:3d} | loss {float(loss):.4f} | " |
| f"val acc {vacc:.4f} (best {best['val_acc']:.4f} @ ep {best['epoch']})") |
| model.load_state_dict(best['state']) |
| model.eval() |
| with torch.no_grad(): |
| tacc = (model(Xts).argmax(dim=-1) == Yts).float().mean().item() |
| out_path.parent.mkdir(parents=True, exist_ok=True) |
| torch.save(best['state'], out_path) |
| print(f" best val acc {best['val_acc']:.4f} | test acc {tacc:.4f}") |
| print(f" saved β {out_path}") |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Per-head training defaults (epoch count and learning rate). `main` reads
# them, and `--epochs` overrides only the epoch counts.
RECIPE = dict(
    predictor90=dict(epochs=80, lr=3e-3),
    classifier=dict(epochs=60, lr=1e-3),
)
|
|
|
|
| |
def main():
    """Command-line entry point: obtain `slots.parquet`, then train each head.

    The parquet comes from `--local-parquet` when given, otherwise from the
    Hugging Face download. Heads are trained in the order d4d5 angular,
    d6d7 angular, digit classifier. Each has its own `--skip-*` flag, and
    each writes its checkpoint under `WEIGHTS`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--cache-dir', default=None,
        help='HF cache root (default: ~/.cache/huggingface)')
    parser.add_argument(
        '--local-parquet', default=None,
        help='Skip HF download and read directly from a local '
             'slots.parquet (bytes embedded). NORMAL usage '
             'downloads from HF.')
    parser.add_argument(
        '--epochs', type=int, default=None,
        help='Override the per-head epoch defaults from RECIPE. '
             'Use only when experimenting; the defaults are what '
             'the production sweep landed on.')
    for flag in ('--skip-classifier', '--skip-d4d5', '--skip-d6d7'):
        parser.add_argument(flag, action='store_true')
    parser.add_argument(
        '--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    def epochs_for(head: str) -> int:
        # An explicit --epochs wins over the recipe value for every head.
        if args.epochs is None:
            return RECIPE[head]['epochs']
        return args.epochs

    def lr_for(head: str) -> float:
        return RECIPE[head]['lr']

    print(f"[start] device={args.device}")
    if args.local_parquet:
        ds = Path(args.local_parquet)
        print(f"[local] using {ds} (skipping HF download)")
    else:
        cache_root = Path(args.cache_dir) if args.cache_dir else None
        ds = download_slots_parquet(cache_root)

    # The two angular heads share a recipe and differ only in slot pool,
    # output file name and skip flag.
    angular_heads = (
        ('d4d5', [4, 5], args.skip_d4d5),
        ('d6d7', [6, 7], args.skip_d6d7),
    )
    for tag, slots, skipped in angular_heads:
        if skipped:
            continue
        n_epochs = epochs_for('predictor90')
        rate = lr_for('predictor90')
        print(f"\n== train {tag}_predictor90 ({n_epochs} ep @ lr {rate}) ==")
        feats, thetas, _, split, _ = load_slot_features(ds, slots, args.device)
        train_predictor90(feats, thetas, split,
                          WEIGHTS / f'{tag}_predictor90.pt',
                          epochs=n_epochs, lr=rate,
                          device=args.device)

    if not args.skip_classifier:
        n_epochs = epochs_for('classifier')
        rate = lr_for('classifier')
        print(f"\n== train digit_classifier (10-class, pooled d4+d5+d6+d7) "
              f"({n_epochs} ep @ lr {rate}) ==")
        feats, _, digits, split, _ = load_slot_features(ds, [4, 5, 6, 7],
                                                        args.device)
        train_d4_classifier(feats, digits, split,
                            WEIGHTS / 'digit_classifier.pt',
                            epochs=n_epochs, lr=rate,
                            device=args.device)

    print(f"\n[done] weights in {WEIGHTS}")
|
|
|
| if __name__ == '__main__': |
| main() |
|
|