cronos3k commited on
Commit
49d5b05
·
verified ·
1 Parent(s): 35d3db1

Fix Tier 2 fallbacks: text-only chat model for reasoning, drop unsupported VLM provider call

Browse files
Files changed (1) hide show
  1. app.py +28 -39
app.py CHANGED
@@ -52,6 +52,17 @@ from legal_doc_redteam.zerogpu_gui import (
52
  REASONING_MODEL_ID = os.environ.get("REASONING_MODEL_ID", DEFAULT_REASONING_MODEL)
53
  VLM_OCR_MODEL_ID = os.environ.get("VLM_OCR_MODEL_ID", DEFAULT_VLM_OCR_MODEL)
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  # Defaults tightened so the @spaces.GPU slice is held only as long as needed;
56
  # this reduces the chance of proxy-token expiry mid-call.
57
  REASONING_GPU_DURATION = int(os.environ.get("REASONING_GPU_DURATION", "60"))
@@ -127,14 +138,14 @@ if spaces is not None:
127
  raise RuntimeError("HF_TOKEN not set; cannot use hf_inference fallback")
128
  from huggingface_hub import InferenceClient
129
 
130
- client = InferenceClient(model=REASONING_MODEL_ID, token=HF_TOKEN_ENV)
131
- extra_body: dict = {}
 
 
132
  effort = (reasoning_effort or "medium").lower()
 
133
  if effort not in {"low", "off", "none", "false", "no"}:
134
- # Gemma 4 / Qwen3
135
  extra_body["enable_thinking"] = True
136
- # gpt-oss family
137
- extra_body["reasoning_effort"] = effort
138
  response = client.chat.completions.create(
139
  messages=[
140
  {"role": "system", "content": SYSTEM_INSTRUCTIONS},
@@ -247,30 +258,18 @@ if spaces is not None:
247
  new_tokens = outputs[0][prompt_len:]
248
  return _vlm_processor.decode(new_tokens, skip_special_tokens=True).strip()
249
 
250
- def _vlm_chat_hf_inference(image_path, prompt: str) -> str:
251
- if not HF_TOKEN_ENV:
252
- raise RuntimeError("HF_TOKEN not set; cannot use hf_inference fallback")
253
- from huggingface_hub import InferenceClient
254
-
255
- image_bytes = Path(str(image_path)).read_bytes()
256
- data_url = "data:image/png;base64," + base64.b64encode(image_bytes).decode("ascii")
257
- client = InferenceClient(model=VLM_OCR_MODEL_ID, token=HF_TOKEN_ENV)
258
- response = client.chat.completions.create(
259
- messages=[
260
- {
261
- "role": "user",
262
- "content": [
263
- {"type": "text", "text": prompt or DEFAULT_VLM_PROMPT},
264
- {"type": "image_url", "image_url": {"url": data_url}},
265
- ],
266
- }
267
- ],
268
- max_tokens=VLM_MAX_NEW_TOKENS,
269
- )
270
- return (response.choices[0].message.content or "").strip()
271
-
272
  def vlm_chat(image_path, prompt: str = DEFAULT_VLM_PROMPT) -> str:
273
- """Three-tier resilient VLM OCR call (per page)."""
 
 
 
 
 
 
 
 
 
 
274
 
275
  last_exc: Exception | None = None
276
  for attempt in range(2):
@@ -286,17 +285,7 @@ if spaces is not None:
286
  if attempt == 0 and _is_transient_gpu_error(exc):
287
  continue
288
  break
289
- try:
290
- print("[hf_zerogpu_space] VLM falling back to hf_inference",
291
- file=sys.stderr)
292
- return _vlm_chat_hf_inference(image_path, prompt)
293
- except Exception as exc:
294
- print(
295
- f"[hf_zerogpu_space] VLM hf_inference fallback failed: "
296
- f"{type(exc).__name__}: {exc}",
297
- file=sys.stderr,
298
- )
299
- raise last_exc or RuntimeError("VLM unavailable (all tiers failed)")
300
 
301
  bind_vlm_fn(vlm_chat, model_id=VLM_OCR_MODEL_ID)
302
  _DEFAULT_VLM = "local_transformers"
 
52
  REASONING_MODEL_ID = os.environ.get("REASONING_MODEL_ID", DEFAULT_REASONING_MODEL)
53
  VLM_OCR_MODEL_ID = os.environ.get("VLM_OCR_MODEL_ID", DEFAULT_VLM_OCR_MODEL)
54
 
55
+ # Tier 2 (HF Inference Providers) needs a model that's actually routable as
56
+ # a chat-completion. Multimodal Gemma 4 E4B is classified as
57
+ # image-text-to-text and rejected by the chat endpoint; we therefore use a
58
+ # separate text-only chat model for the hf_inference fallback. Override with
59
+ # REASONING_HF_INFERENCE_MODEL_ID if your HF account has a different model
60
+ # enabled on Inference Providers.
61
+ REASONING_HF_INFERENCE_MODEL_ID = os.environ.get(
62
+ "REASONING_HF_INFERENCE_MODEL_ID",
63
+ "openai/gpt-oss-20b",
64
+ )
65
+
66
  # Defaults tightened so the @spaces.GPU slice is held only as long as needed;
67
  # this reduces the chance of proxy-token expiry mid-call.
68
  REASONING_GPU_DURATION = int(os.environ.get("REASONING_GPU_DURATION", "60"))
 
138
  raise RuntimeError("HF_TOKEN not set; cannot use hf_inference fallback")
139
  from huggingface_hub import InferenceClient
140
 
141
+ client = InferenceClient(
142
+ model=REASONING_HF_INFERENCE_MODEL_ID,
143
+ token=HF_TOKEN_ENV,
144
+ )
145
  effort = (reasoning_effort or "medium").lower()
146
+ extra_body: dict = {"reasoning_effort": effort}
147
  if effort not in {"low", "off", "none", "false", "no"}:
 
148
  extra_body["enable_thinking"] = True
 
 
149
  response = client.chat.completions.create(
150
  messages=[
151
  {"role": "system", "content": SYSTEM_INSTRUCTIONS},
 
258
  new_tokens = outputs[0][prompt_len:]
259
  return _vlm_processor.decode(new_tokens, skip_special_tokens=True).strip()
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  def vlm_chat(image_path, prompt: str = DEFAULT_VLM_PROMPT) -> str:
262
+ """Resilient VLM OCR call (per page).
263
+
264
+ Tier 1 only — local @spaces.GPU with one retry on transient
265
+ ZeroGPU errors. There is no Tier 2 for the VLM: the default
266
+ ``nanonets/Nanonets-OCR-s`` is not hosted on HF Inference
267
+ Providers and trying to route it there returned
268
+ ``model_not_supported`` errors that just delayed the failure.
269
+ On VLM failure the per-page OCR loop in ``ocr_integrity``
270
+ records the warning and proceeds with the three CPU OCR
271
+ engines, which already give multi-engine page coverage.
272
+ """
273
 
274
  last_exc: Exception | None = None
275
  for attempt in range(2):
 
285
  if attempt == 0 and _is_transient_gpu_error(exc):
286
  continue
287
  break
288
+ raise last_exc or RuntimeError("VLM unavailable (local GPU failed)")
 
 
 
 
 
 
 
 
 
 
289
 
290
  bind_vlm_fn(vlm_chat, model_id=VLM_OCR_MODEL_ID)
291
  _DEFAULT_VLM = "local_transformers"