bingyan user Cursor committed on
Commit
3f4e2ae
·
1 Parent(s): 7c95286

Joint Phase-2 OOD heads + missing runtime modules

Browse files

- ec_joint_best.pt + tpd_joint_best.pt now contain the OOD-equipped
joint checkpoints (val AUROC 0.965 EC, 0.967 TPD); From-Image tab
will show OOD warnings under joint mode just like the headline path.
- Add image_encoder.py (imported by multi_mechanism_model.py and
tpd_model.py for the Phase-2 joint encoder) and image_preprocessing.py
(used by app.py for auto-crop + gridline removal preview).

Co-authored-by: Cursor <cursoragent@cursor.com>

checkpoints/ec_joint_best.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c38583b033e633973ee1f63eddec4259ef103494f474f6044f949951e8eb321e
3
- size 42047802
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25058db0f5e010c594fe03f2076a53eaac276820aba90cc28c6a32a4a7f32582
3
+ size 42208892
checkpoints/tpd_joint_best.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:592fe7cb6005777ff66edb4c34942c98ea6b42ac4be8fed10dd8826ac4ceff4d
3
- size 52644074
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c47748dea284999cd34e98046576b526d2706e7e9315e5a53ac42f8087596ac
3
+ size 52833544
image_encoder.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image encoder for the image-input TRACE variants (CV and TPD).
3
+
4
+ Replaces the per-scan 1-D `SignalEncoder` with a small 2-D CNN that maps a
5
+ single rasterized plot image (grayscale, 224x224 by default) to a context
6
+ vector of the same dimensionality `d_context`. Everything downstream of the
7
+ per-scan branch (`cv_augment`, SAB, PMA, classifier, flow heads, OOD head)
8
+ stays unchanged, so this is a drop-in replacement.
9
+
10
+ Input: [B, 1, H, W] grayscale plot image, values in [0, 1].
11
+ Output: [B, d_context]
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
class ImageEncoder(nn.Module):
    """Small 2-D CNN encoder for plot images.

    Architecture (400,032 trainable parameters, i.e. about 0.4M, at the
    default settings; spatial sizes are shown for a 224x224 input):
        Conv 1->32        k=7 s=2  -> BN -> GELU   -> 112x112
        Conv 32->64       k=5 s=2  -> BN -> GELU   -> 56x56
        Conv 64->96       k=3 s=2  -> BN -> GELU   -> 28x28
        Conv 96->128      k=3 s=2  -> BN -> GELU   -> 14x14
        Conv 128->d_model k=3 s=2  -> BN -> GELU   -> 7x7
        Adaptive avg pool                          -> [B, d_model]
        MLP d_model -> d_context -> d_context (GELU + dropout in between)

    Designed to be light enough to train from scratch alongside the existing
    classifier and flow heads, while still having the receptive field needed
    to read curve shape across the whole image.

    Args:
        in_channels: number of channels of the input image (1 = grayscale).
        d_model: channel width of the last convolutional block, and the
            size of the pooled feature vector.
        d_context: size of the returned context vector.
        dropout: dropout probability used inside the projection MLP.
    """

    def __init__(self, in_channels: int = 1, d_model: int = 128,
                 d_context: int = 128, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.d_model = d_model
        self.d_context = d_context

        def block(c_in, c_out, k, s):
            # Conv -> BatchNorm -> GELU. The convolution has no bias because
            # the BatchNorm that follows provides its own learnable shift.
            # With odd k and padding k // 2, a stride-2 block maps a spatial
            # size H to ceil(H / 2).
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=s,
                          padding=k // 2, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            )

        # NOTE: the submodule layout (stem.<i>.<j>, proj.<i>) determines the
        # state_dict keys, so it must stay as-is for saved weights to load.
        self.stem = nn.Sequential(
            block(in_channels, 32, k=7, s=2),
            block(32, 64, k=5, s=2),
            block(64, 96, k=3, s=2),
            block(96, 128, k=3, s=2),
            block(128, d_model, k=3, s=2),
        )
        # Global average pooling makes the encoder independent of the exact
        # input resolution.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of plot images.

        Args:
            x: [B, in_channels, H, W] image tensor in [0, 1].

        Returns:
            context: [B, d_context]
        """
        h = self.stem(x)
        # [B, d_model, 1, 1] -> [B, d_model]
        h = self.pool(h).flatten(1)
        return self.proj(h)
76
+
77
+
78
def count_parameters(module: nn.Module) -> int:
    """Return the number of trainable scalar parameters in `module`.

    Only parameters with `requires_grad=True` are counted; frozen
    parameters are ignored.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
80
+
81
+
82
if __name__ == "__main__":
    # Smoke test: build the default encoder, push a dummy batch through it
    # and report the parameter count plus input/output shapes.
    encoder = ImageEncoder(in_channels=1, d_model=128, d_context=128)
    dummy_batch = torch.zeros(2, 1, 224, 224)
    context = encoder(dummy_batch)
    report = (
        f"ImageEncoder params: {count_parameters(encoder):,}",
        f"Input shape: {tuple(dummy_batch.shape)}",
        f"Output shape: {tuple(context.shape)}",
    )
    for line in report:
        print(line)
image_preprocessing.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image preprocessing for image-input TRACE on real-world uploads.
3
+
4
+ Real user uploads (paper-figure crops, software screenshots, photos of
5
+ lab monitors) live in a much wider image distribution than the rendered
6
+ training PNGs. This module produces a cleaned grayscale 224x224 PIL image
7
+ that looks closer to the training distribution before it enters the
8
+ image-mode CNN.
9
+
10
+ Three stages, all PIL-in / PIL-out:
11
+
12
+ 1. crop_to_plot_region -- OCR-based detection of the inner plot
13
+ bounding box; crops out browser chrome,
14
+ paper captions, side panels.
15
+ 2. remove_gridlines_and_background
16
+ -- adaptive threshold + morphological line
17
+ detection to suppress thin gridlines and
18
+ normalize the background to white.
19
+ 3. prepare_for_image_mode
20
+ -- orchestrator: crop -> clean -> resize.
21
+
22
+ All heavy CV deps (`cv2`, `easyocr`) are imported lazily so the module
23
+ loads cleanly without them: no `easyocr` -> no crop (bbox is None); no `cv2`
24
+ -> no gridline removal (`was_cleaned` False); stretch + soft-threshold still run.
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from typing import Dict, Optional, Tuple
30
+
31
+ import numpy as np
32
+ from PIL import Image, ImageOps
33
+
34
+
35
+ # --------------------------------------------------------------------------
36
+ # Plot-region cropping (OCR-based)
37
+ # --------------------------------------------------------------------------
38
+
39
def _detect_label_positions(image_array: np.ndarray):
    """Run OCR and return raw (cx, cy, val) tuples for every numeric label.

    Mirrors the OCR pass in `digitizer.auto_detect_axis_bounds` but exposes
    the per-label pixel positions, which we need to locate the inner plot
    bounding box (right of y-labels, above x-labels).

    Args:
        image_array: image as a numpy array, [H, W] or [H, W, C]. A fourth
            (alpha) channel is dropped before OCR.

    Returns:
        (detections, (H, W)) where detections is a list of
        (center_x, center_y, numeric_value) tuples in pixel coordinates.
        Returns ([], None) if easyocr is unavailable, if the OCR reader
        cannot be created or fails while reading, or if fewer than 4
        numeric labels are found.
    """
    try:
        import easyocr
    except ImportError:
        return [], None
    import re

    # Drop the alpha channel of RGBA input.
    if image_array.ndim == 3 and image_array.shape[2] == 4:
        image_array = image_array[:, :, :3]

    H, W = image_array.shape[:2]
    try:
        # Building the reader can fail as well as reading (for example when
        # the OCR model cannot be loaded). Both are treated as "no
        # detections" so OCR stays a best-effort step for the caller.
        reader = easyocr.Reader(["en"], gpu=False, verbose=False)
        results = reader.readtext(image_array, detail=1)
    except Exception:
        return [], None

    # Optional sign, digits, optional fraction, optional exponent. A tilde
    # and the Unicode minus / en dash are accepted as sign characters and
    # normalized to an ASCII "-" before matching.
    numeric_re = re.compile(r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$")
    detections = []
    for bbox, text, conf in results:
        # Skip low-confidence reads.
        if conf < 0.2:
            continue
        cleaned = (text.strip().replace(" ", "")
                   .replace("−", "-").replace("–", "-").replace("~", "-"))
        if not numeric_re.match(cleaned):
            continue
        try:
            value = float(cleaned)
        except ValueError:
            continue
        # Label position = mean of the bounding-box corner coordinates.
        cx = float(np.mean([p[0] for p in bbox]))
        cy = float(np.mean([p[1] for p in bbox]))
        detections.append((cx, cy, value))

    # Too few labels to say anything reliable about the axes.
    if len(detections) < 4:
        return [], None
    return detections, (H, W)
84
+
85
+
86
def _plot_bbox_from_detections(detections, hw, margin_frac: float = 0.02):
    """Compute inner-plot bounding box (left, top, right, bottom) in pixels
    from raw OCR label detections.

    Heuristic:
      - y-axis label candidates live in the left 30% of the image
      - x-axis label candidates live in the bottom 35% of the image
      - labels in the bottom-left corner belong to both bands and are
        ambiguous (a low y tick looks like an x-label and a leftmost
        x tick looks like a y-label), so they are ignored on an axis
        whenever that axis still has unambiguous labels left
      - plot_left   = max cx among y-labels + margin
      - plot_top    = min cy among y-labels - margin
      - plot_right  = max cx among x-labels + margin
      - plot_bottom = min cy among x-labels - margin

    Args:
        detections: iterable of (cx, cy, value) tuples in pixel coordinates.
        hw: (H, W) of the image the detections were made on.
        margin_frac: margin as a fraction of max(H, W).

    Returns (left, top, right, bottom) ints, or None if either band has no
    labels or the resulting box is narrower or shorter than 32 px.
    """
    H, W = hw
    margin = int(margin_frac * max(H, W))
    left_limit = W * 0.30
    bottom_limit = H * 0.65

    left_band = [(cx, cy) for cx, cy, _ in detections if cx < left_limit]
    bottom_band = [(cx, cy) for cx, cy, _ in detections if cy > bottom_limit]
    if not left_band or not bottom_band:
        return None

    # Drop the ambiguous corner labels. Without this, a y tick in the lower
    # part of the axis would pull plot_bottom up into the plot, and an x tick
    # near the left would push plot_left into the plot. Fall back to the full
    # band when nothing unambiguous remains.
    y_labels = [(cx, cy) for cx, cy in left_band if cy <= bottom_limit] or left_band
    x_labels = [(cx, cy) for cx, cy in bottom_band if cx >= left_limit] or bottom_band

    plot_left = int(max(cx for cx, _ in y_labels) + margin)
    plot_top = int(min(cy for _, cy in y_labels) - margin)
    plot_right = int(max(cx for cx, _ in x_labels) + margin)
    plot_bottom = int(min(cy for _, cy in x_labels) - margin)

    # Clamp to the image and keep the box at least 1 px wide and tall.
    plot_left = max(0, min(plot_left, W - 1))
    plot_right = max(plot_left + 1, min(plot_right, W))
    plot_top = max(0, min(plot_top, H - 1))
    plot_bottom = max(plot_top + 1, min(plot_bottom, H))

    # Reject degenerate boxes; the caller then keeps the uncropped image.
    if plot_right - plot_left < 32 or plot_bottom - plot_top < 32:
        return None
    return (plot_left, plot_top, plot_right, plot_bottom)
124
+
125
+
126
def crop_to_plot_region(pil_image: Image.Image,
                        margin_frac: float = 0.02,
                        ) -> Tuple[Image.Image, Optional[Tuple[int, int, int, int]]]:
    """Crop an image to the inner plot area located via OCR of axis labels.

    Args:
        pil_image: input PIL image (any mode). OCR runs on an RGB copy; the
            crop itself is taken from the original image.
        margin_frac: padding around the detected plot region, as a fraction
            of max(H, W).

    Returns:
        (cropped_pil, bbox) with bbox = (left, top, right, bottom) ints.
        When detection fails, bbox is None and the input image is returned
        unchanged.
    """
    rgb = np.asarray(pil_image.convert("RGB"))
    detections, shape = _detect_label_positions(rgb)

    bbox = None
    if detections and shape is not None:
        bbox = _plot_bbox_from_detections(detections, shape,
                                          margin_frac=margin_frac)
    if bbox is None:
        return pil_image, None
    return pil_image.crop(bbox), bbox
150
+
151
+
152
+ # --------------------------------------------------------------------------
153
+ # Background normalization + gridline removal (CV2-based)
154
+ # --------------------------------------------------------------------------
155
+
156
def _ensure_grayscale(pil_image: Image.Image) -> np.ndarray:
    """Convert any PIL image to a uint8 grayscale numpy array.

    Images already in mode "L" are used as-is; every other mode goes
    through `convert("L")` first.
    """
    gray = pil_image if pil_image.mode == "L" else pil_image.convert("L")
    return np.asarray(gray, dtype=np.uint8)
161
+
162
+
163
def remove_gridlines_and_background(
    pil_image: Image.Image,
    background_stretch: bool = True,
    remove_gridlines: bool = True,
    grid_min_length_frac: float = 0.30,
    soft_threshold: int = 245,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Whiten the background and optionally erase long straight lines.

    Steps:
      1. Convert the image to 8-bit grayscale.
      2. If `background_stretch`: scale all gray values linearly so that
         the brightest pixel becomes 255 (skipped for an all-black image).
      3. If `remove_gridlines` and cv2 is importable: build an inverted
         adaptive-threshold mask of dark pixels, then apply a morphological
         opening with a long horizontal kernel (K_h wide, 1 tall) and a long
         vertical kernel (1 wide, K_v tall). K is `grid_min_length_frac` of
         the image width/height, at least 20 px. Only straight runs of at
         least K dark pixels survive the opening; the surviving pixels are
         dilated slightly and inpainted on the grayscale image.
      4. If `soft_threshold > 0`: set every pixel >= `soft_threshold` to 255.

    Without cv2, step 3 is skipped and the other steps still run.

    Returns:
        (cleaned_pil, meta). `cleaned_pil` is a mode "L" image. `meta` has
        the keys was_stretched, was_cleaned, n_horiz_gridlines and
        n_vert_gridlines; the two counts are the number of pixel rows /
        columns that contain at least one detected line pixel.
    """
    meta: Dict[str, object] = dict(
        was_stretched=False,
        was_cleaned=False,
        n_horiz_gridlines=0,
        n_vert_gridlines=0,
    )
    gray = _ensure_grayscale(pil_image)

    if background_stretch and gray.max() > 0:
        gain = 255.0 / float(gray.max())
        stretched = gray.astype(np.float32) * gain
        gray = np.clip(stretched, 0, 255).astype(np.uint8)
        meta["was_stretched"] = True

    # cv2 is optional; without it only the line-removal stage is skipped.
    try:
        import cv2
    except ImportError:
        cv2 = None

    if cv2 is not None and remove_gridlines:
        height, width = gray.shape
        dark = cv2.adaptiveThreshold(
            gray, 255,
            cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV,
            blockSize=31, C=10,
        )
        horiz_len = max(20, int(width * grid_min_length_frac))
        vert_len = max(20, int(height * grid_min_length_frac))

        horiz = cv2.morphologyEx(
            dark, cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (horiz_len, 1)),
        )
        vert = cv2.morphologyEx(
            dark, cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_len)),
        )

        # Rows (columns) that contain any horizontal (vertical) line pixel.
        meta["n_horiz_gridlines"] = int((horiz.sum(axis=1) > 0).sum())
        meta["n_vert_gridlines"] = int((vert.sum(axis=0) > 0).sum())

        mask = cv2.bitwise_or(horiz, vert)
        if mask.sum() > 0:
            # Grow the mask a little so the whole line width gets inpainted.
            mask = cv2.dilate(mask, np.ones((2, 2), np.uint8))
            gray = cv2.inpaint(gray, mask, 3, cv2.INPAINT_TELEA)
            meta["was_cleaned"] = True

    if soft_threshold > 0:
        gray = np.where(gray >= soft_threshold, 255, gray).astype(np.uint8)

    return Image.fromarray(gray, mode="L"), meta
241
+
242
+
243
+ # --------------------------------------------------------------------------
244
+ # Orchestrator
245
+ # --------------------------------------------------------------------------
246
+
247
def prepare_for_image_mode(
    pil_image: Image.Image,
    do_crop: bool = True,
    do_clean: bool = True,
    target_size: int = 224,
) -> Tuple[Image.Image, Dict[str, object]]:
    """Run the full image-mode preprocessing chain: crop -> clean -> resize.

    Args:
        pil_image: any-mode PIL.Image.
        do_crop: run OCR-based plot-region cropping.
        do_clean: run background normalization + gridline removal. When
            False, the image is only converted to grayscale.
        target_size: edge length of the square output image.

    Returns:
        (preprocessed_pil_L, meta) where meta is a flat dict suitable for
        showing in the UI:
            was_cropped:        bool
            crop_bbox:          [l, t, r, b] list, or None if not cropped
            was_stretched:      bool
            was_cleaned:        bool
            n_horiz_gridlines:  int
            n_vert_gridlines:   int
            target_size:        int
    """
    meta: Dict[str, object] = {
        "was_cropped": False,
        "crop_bbox": None,
        "was_stretched": False,
        "was_cleaned": False,
        "n_horiz_gridlines": 0,
        "n_vert_gridlines": 0,
        "target_size": target_size,
    }
    img = pil_image

    if do_crop:
        candidate, bbox = crop_to_plot_region(img)
        if bbox is not None:
            img = candidate
            meta["was_cropped"] = True
            meta["crop_bbox"] = list(bbox)

    if do_clean:
        img, clean_meta = remove_gridlines_and_background(img)
        for key in ("was_stretched", "was_cleaned",
                    "n_horiz_gridlines", "n_vert_gridlines"):
            meta[key] = clean_meta[key]
    elif img.mode != "L":
        # The cleaning stage would have produced grayscale; do it here.
        img = img.convert("L")

    wanted = (target_size, target_size)
    if img.size != wanted:
        img = img.resize(wanted, Image.BILINEAR)

    return img, meta
308
+
309
+
310
# Public API of this module: the two individual stages plus the orchestrator.
# The underscore-prefixed helpers are left out on purpose.
__all__ = [
    "crop_to_plot_region",
    "remove_gridlines_and_background",
    "prepare_for_image_mode",
]