specimba commited on
Commit
679aa22
·
verified ·
1 Parent(s): a8cc3e1

Remove guards.py: was importing torch at module-level causing builder side-effects

Browse files
Files changed (1) hide show
  1. guards.py +0 -132
guards.py DELETED
@@ -1,132 +0,0 @@
1
- """
2
- NEXUS LAB - 4 Specialist Guard System
3
- ======================================
4
- 4 x BERT-base classifiers (< 110MB each, ~440MB total)
5
- Ensemble: Gemma-4B takes 4 specialist scores + raw input for final verdict
6
-
7
- Domains:
8
- 1. MCP/Tool Contamination - tool misuse, parameter injection, supply chain
9
- 2. Multi-Agent Collusion - identity spoofing, coordinated attacks
10
- 3. Content Safety - hate, sexual, violence, self-harm, CSAM
11
- 4. Jailbreak/Prompt Engineering - DAN, role-play, ignore-previous
12
-
13
- Training: LoRA on BERT-base (66MB per adapter, 4 = 264MB)
14
- VRAM fit: 440MB (FP16 BERT) + 264MB (LoRA) + 3.1GB (Gemma Q4) = <4GB total
15
- """
16
-
17
- import os, json, torch
18
- from typing import Dict, List, Optional
19
- from dataclasses import dataclass
20
-
21
- GUARD_LABELS = {
22
- "mcp": ["safe", "tool_param_injection", "supply_chain_attack", "mcp_surface_exploit"],
23
- "collusion": ["safe", "identity_spoof", "governance_bypass", "coordinated_attack"],
24
- "content": ["safe", "hate_speech", "sexual_content", "violence", "self_harm", "csam"],
25
- "jailbreak": ["safe", "dan", "roleplay_bypass", "ignore_previous", "encoding_attack", "jailbreak_pattern"],
26
- }
27
-
28
- @dataclass
29
- class GuardResult:
30
- domain: str
31
- label: str
32
- confidence: float
33
- scores: Dict[str, float]
34
-
35
- class SpecialistGuard:
36
- """Single BERT-base specialist with LoRA adapter."""
37
-
38
- BASE_MODEL = "google-bert/bert-base-uncased"
39
- ADAPTER_PATH = "specimba/nexus-guard-{}" # e.g. specimba/nexus-guard-mcp
40
-
41
- def __init__(self, domain: str):
42
- self.domain = domain
43
- self.labels = GUARD_LABELS[domain]
44
- self.model = None
45
- self.tokenizer = None
46
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
47
- self._cls = AutoModelForSequenceClassification
48
- self._tok = AutoTokenizer
49
-
50
- def load(self, device: str = "cpu"):
51
- import os
52
- path = self.ADAPTER_PATH.format(self.domain)
53
- self.tokenizer = self._tok.from_pretrained(self.BASE_MODEL)
54
- base = self._cls.from_pretrained(self.BASE_MODEL, num_labels=len(self.labels))
55
- try:
56
- from peft import PeftModel
57
- self.model = PeftModel.from_pretrained(base, path, adapter_name=self.domain)
58
- except Exception:
59
- self.model = base # Fallback: base model without LoRA
60
- self.model = self.model.to(device)
61
- return self
62
-
63
- def predict(self, text: str) -> GuardResult:
64
- if self.model is None:
65
- self.load()
66
- inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.model.device)
67
- with torch.no_grad():
68
- logits = self.model(**inputs).logits
69
- probs = torch.softmax(logits, dim=-1)[0]
70
- top_idx = torch.argmax(probs).item()
71
- scores = {self.labels[i]: round(probs[i].item(), 4) for i in range(len(self.labels))}
72
- return GuardResult(
73
- domain=self.domain,
74
- label=self.labels[top_idx],
75
- confidence=round(probs[top_idx].item(), 4),
76
- scores=scores,
77
- )
78
-
79
- class GuardEnsemble:
80
- """4 Specialist Guards + Gemma-4B ensemble for final verdict."""
81
-
82
- GEMMA_MODEL = "google/gemma-3-4b-it"
83
-
84
- def __init__(self, device_map: str = "auto"):
85
- self.guards = {d: SpecialistGuard(d) for d in GUARD_LABELS}
86
-
87
- def evaluate(self, text: str, fast_path: bool = True) -> Dict:
88
- """Fast path: if all guards agree safe, skip Gemma."""
89
- results = {d: g.predict(text) for d, g in self.guards.items()}
90
- all_safe = all(r.label == "safe" and r.confidence > 0.95 for r in results.values())
91
-
92
- if fast_path and all_safe:
93
- return {"verdict": "SAFE", "path": "fast", "guards": {d: asdict(r) for d,r in results.items()}}
94
-
95
- # Slow path: Ensemble judgment
96
- prompt = self._build_ensemble_prompt(text, results)
97
- verdict = self._gemma_judge(prompt)
98
- return {"verdict": verdict, "path": "ensemble", "guards": {d: asdict(r) for d,r in results.items()}}
99
-
100
- def _build_ensemble_prompt(self, text: str, results: Dict) -> str:
101
- scores = []
102
- for d, r in results.items():
103
- top = f"{r.label}({r.confidence:.2%})"
104
- scores.append(f"{d}: {top}")
105
- return (
106
- f"Text to evaluate: {text[:500]}\n"
107
- f"Specialist scores: {', '.join(scores)}\n"
108
- "Based on these specialist evaluations, provide a final verdict:\n"
109
- "VERDICT: [SAFE | UNSAFE - then explain why in one sentence]"
110
- )
111
-
112
- def _gemma_judge(self, prompt: str) -> str:
113
- try:
114
- from transformers import AutoModelForCausalLM, AutoTokenizer
115
- tok = AutoTokenizer.from_pretrained(self.GEMMA_MODEL)
116
- model = AutoModelForCausalLM.from_pretrained(self.GEMMA_MODEL, device_map="auto")
117
- inputs = tok(prompt, return_tensors="pt").to(model.device)
118
- out = model.generate(**inputs, max_new_tokens=100, do_sample=False)
119
- return tok.decode(out[0], skip_special_tokens=True)[len(prompt):].strip()
120
- except Exception as e:
121
- return f"ERROR: {str(e)[:200]}"
122
-
123
- def asdict(result: GuardResult):
124
- return {"domain": result.domain, "label": result.label, "confidence": result.confidence, "scores": result.scores}
125
-
126
- # Pre-load stubs for when models aren't downloaded
127
- NO_MODEL_RESPONSE = {
128
- "mcp": {"label": "safe", "confidence": 0.0, "note": "model not loaded"},
129
- "collusion": {"label": "safe", "confidence": 0.0, "note": "model not loaded"},
130
- "content": {"label": "safe", "confidence": 0.0, "note": "model not loaded"},
131
- "jailbreak": {"label": "safe", "confidence": 0.0, "note": "model not loaded"},
132
- }