File size: 6,225 Bytes
3800bd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Model architectures + DINOv2 feature extraction for the badger-55
meter reader. Two heads are trained:

  - `SlotClassifier` — 10-class digit classifier per slot (used for d0..d4
    in this demo; the upper drums had constant ground-truth labels during
    data collection, so a classifier trained on the pooled set only learns
    the constant for those slots).
  - `Predictor90` — 90-bin angular classifier over a slot's drum rotation,
    trained with wrapped-Gaussian soft targets. Used for d5..d7. The
    decoder picks a continuous theta via the circular mean of the softmax,
    giving sub-bin precision.

A `SinCosSpecialist` head used to live here as a third voter. It was
removed 2026-05-24 — its validation MAE was 2–3× worse than Predictor90's,
the consensus never picked it over the primary, and it only added noise to
the demo render."""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoImageProcessor


DINOV2_ID: str = 'facebook/dinov2-small'   # HF hub id loaded by `DinoV2`
DINOV2_DIM: int = 384                      # width of the CLS-token feature vector
DINOV2_SIZE: int = 224                     # square input side, in pixels
N_BINS: int = 90                           # angular bins in the Predictor90 head
BIN_DEG: float = 360.0 / N_BINS            # degrees spanned by one bin


# ── architectures ─────────────────────────────────────────────────────
class SlotClassifier(nn.Module):
    """Digit head for one slot: feature vector -> 10 raw logits.

    Layout: Linear(in_dim, hidden) -> GELU -> Dropout(dropout)
    -> Linear(hidden, 10). No softmax is applied here.
    """

    def __init__(self, in_dim=DINOV2_DIM, hidden=128, dropout=0.15):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 10),
        ]
        # Attribute name `net` and the layer order determine the state_dict
        # keys (net.0.*, net.3.*).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Predictor90(nn.Module):
    """Angular-bin head: feature vector -> `n_bins` raw logits.

    Two hidden stages (Linear -> GELU -> Dropout, width `hidden`) followed by
    a Linear projection to `n_bins`. Turning the logits into an angle
    (softmax + circular mean) is left to the caller, e.g.
    `predictor90_decode`.
    """

    def __init__(self, in_dim=DINOV2_DIM, hidden=128, dropout=0.1, n_bins=N_BINS):
        super().__init__()

        def stage(width_in):
            # One Linear -> GELU -> Dropout stage projecting width_in -> hidden.
            return [nn.Linear(width_in, hidden), nn.GELU(), nn.Dropout(dropout)]

        # Flattened into a single Sequential, so the state_dict keys are
        # mlp.0.*, mlp.3.* and mlp.6.*.
        self.mlp = nn.Sequential(
            *stage(in_dim),
            *stage(hidden),
            nn.Linear(hidden, n_bins),
        )

    def forward(self, x):
        return self.mlp(x)


# ── soft-target helpers ───────────────────────────────────────────────
def wrapped_gaussian_targets(theta_deg, n_bins=N_BINS, sigma_bins=2.0):
    """Soft targets for the Predictor90 head: a Gaussian on the circle.

    Each bin's unnormalised weight is ``exp(-d**2 / (2 * sigma_bins**2))``,
    where ``d`` is the shortest circular distance, in bins, from the target
    angle to the bin centre. The Gaussian is sampled at bin centres rather
    than integrated over each bin. Rows are normalised to sum to 1.

    Args:
        theta_deg: target angle(s) in degrees, as a scalar (including a 0-d
            array) or a 1-d array-like. Any real value is accepted; angles
            are wrapped modulo 360 first.
        n_bins: number of equal-width bins covering 360 degrees. Bin ``k`` is
            centred at ``(k + 0.5) * 360 / n_bins``.
        sigma_bins: Gaussian standard deviation, measured in bins.

    Returns:
        float32 array of shape ``(n_bins,)`` for a scalar input, otherwise
        ``(N, n_bins)`` for ``N`` input angles.
    """
    bin_deg = 360.0 / n_bins
    centers = np.arange(n_bins, dtype=np.float32) * bin_deg + bin_deg / 2
    # ndim == 0 covers Python scalars, numpy scalars and 0-d arrays alike.
    is_scalar = np.ndim(theta_deg) == 0
    # Wrap into [0, 360] first so that min(d, 360 - d) below is a genuine
    # circular distance for any input (e.g. -370 or 710 degrees).
    t = np.mod(np.atleast_1d(np.asarray(theta_deg, dtype=np.float32)), 360.0)
    d = np.abs(centers[None, :] - t[:, None])
    d = np.minimum(d, 360.0 - d)
    d_bins = d / bin_deg
    target = np.exp(-(d_bins ** 2) / (2.0 * sigma_bins ** 2))
    target = target / target.sum(axis=-1, keepdims=True)
    # Index instead of squeeze(): squeeze() would also drop the bin axis
    # when n_bins == 1.
    return target[0] if is_scalar else target


def predictor90_decode(logits: torch.Tensor, n_bins=N_BINS):
    """Decode Predictor90 logits into an angle, a digit and confidence stats.

    Args:
        logits: raw scores whose last dimension has ``n_bins`` entries
            (typically ``(B, n_bins)`` straight from ``Predictor90``).
        n_bins: number of equal-width angular bins covering 360 degrees.
            Bin ``k`` is centred at ``(k + 0.5) * 360 / n_bins``.

    Returns:
        dict with
          - ``theta_deg``: circular mean of the softmax over the bin
            centres, in ``[0, 360)``. This gives sub-bin precision.
          - ``digit``: ``floor(theta / 36) % 10`` as int64 (ten 36-degree
            sectors).
          - ``top1_prob``: the largest softmax probability.
          - ``entropy``: Shannon entropy of the softmax, in nats.
          - ``probs``: the softmax itself.
    """
    bin_deg = 360.0 / n_bins
    probs = F.softmax(logits, dim=-1)
    centers_deg = (torch.arange(n_bins, device=logits.device, dtype=logits.dtype)
                   * bin_deg + bin_deg / 2.0)
    centers_rad = centers_deg * (np.pi / 180.0)
    sin_m = (probs * torch.sin(centers_rad)).sum(dim=-1)
    cos_m = (probs * torch.cos(centers_rad)).sum(dim=-1)
    theta = (torch.atan2(sin_m, cos_m) * 180.0 / np.pi) % 360.0
    # A tiny negative atan2 angle wraps to 360 - eps, which can round to
    # exactly 360.0 in float32/float16. Fold it back so theta stays in
    # [0, 360).
    theta = torch.where(theta >= 360.0, theta - 360.0, theta)
    top1_prob, _ = probs.max(dim=-1)
    # Entropy in nats. log_softmax stays finite where softmax underflows to
    # 0. A clamp such as clamp_min(1e-12) rounds to 0 in float16, which
    # gives log(0) = -inf and 0 * -inf = NaN. Bins with zero probability
    # contribute exactly 0, which also covers -inf (masked) logits.
    logp = F.log_softmax(logits, dim=-1)
    plogp = torch.where(probs > 0, probs * logp, torch.zeros_like(probs))
    entropy = -plogp.sum(dim=-1)
    digit = (theta // 36.0).long() % 10
    return {'theta_deg': theta, 'digit': digit, 'top1_prob': top1_prob,
            'entropy': entropy, 'probs': probs}


# ── DINOv2 feature extractor ──────────────────────────────────────────
class DinoV2:
    """Thin wrapper around the public `facebook/dinov2-small` HF model.

    `features` returns CLS-token features (the first token of
    `last_hidden_state`) of shape `(N, 384)`. The model is frozen: it is put
    in eval mode, inference runs under `torch.no_grad()`, and no
    fine-tuning happens here.
    """
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD  = (0.229, 0.224, 0.225)

    def __init__(self, device: str | torch.device | None = None):
        """Load the image processor and model via `from_pretrained(DINOV2_ID)`.

        Args:
            device: target device. Defaults to 'cuda' when available,
                otherwise 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # NOTE(review): `self.proc` is not used by `features`, which does its
        # own ImageNet normalisation below. It is kept as a public attribute
        # for callers.
        self.proc = AutoImageProcessor.from_pretrained(DINOV2_ID)
        self.model = AutoModel.from_pretrained(DINOV2_ID).to(self.device).eval()
        # Shaped (1, 3, 1, 1) so they broadcast over an (N, 3, H, W) batch.
        self._mean = torch.tensor(self.IMAGENET_MEAN,
                                   device=self.device).view(1, 3, 1, 1)
        self._std  = torch.tensor(self.IMAGENET_STD,
                                   device=self.device).view(1, 3, 1, 1)

    @torch.no_grad()
    def features(self, slot_array_chw_01: np.ndarray) -> torch.Tensor:
        """Embed a batch of slot crops.

        Input: `(N, 3, 224, 224)` values in `[0, 1]`, RGB (e.g. from
        `slot_crops_to_array`). Any dtype or memory layout is accepted; it
        is converted to a contiguous float32 array first.
        Output: `(N, 384)` features on the model's device.
        """
        # torch.from_numpy rejects negative-stride views (e.g. a `::-1`
        # channel flip) and keeps the source dtype. A float64 batch would
        # then clash with the float32 model weights.
        arr = np.ascontiguousarray(slot_array_chw_01, dtype=np.float32)
        x = torch.from_numpy(arr).to(self.device)
        x = (x - self._mean) / self._std
        out = self.model(pixel_values=x).last_hidden_state[:, 0, :]
        return out


def slot_crops_to_array(slot_bgrs: list[np.ndarray]) -> np.ndarray:
    """Pack BGR slot crops (any spatial size) into one DinoV2-ready batch.

    Each crop is converted BGR -> RGB and bilinearly resized to
    DINOV2_SIZE x DINOV2_SIZE when it is not already that size. It is then
    moved to channel-first layout and divided by 255.
    NOTE(review): the /255 scaling assumes 8-bit crops; confirm with callers.

    Returns a float32 array of shape `(N, 3, DINOV2_SIZE, DINOV2_SIZE)`.
    N is `len(slot_bgrs)` and may be 0.
    """
    import cv2
    side = DINOV2_SIZE
    square = (side, side)
    batch = np.zeros((len(slot_bgrs), 3, side, side), dtype=np.float32)
    for idx, crop in enumerate(slot_bgrs):
        img = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
        needs_resize = img.shape[:2] != square
        if needs_resize:
            # cv2 takes dsize as (width, height); square, so order is moot.
            img = cv2.resize(img, square, interpolation=cv2.INTER_LINEAR)
        chw = np.transpose(img, (2, 0, 1))
        batch[idx] = chw.astype(np.float32) / 255.0
    return batch