antebe1 commited on
Commit
ebdb5ae
·
verified ·
1 Parent(s): 15ccc67

Upload DFC CrossCoder model

Browse files
Files changed (10) hide show
  1. README.md +120 -0
  2. app.py +231 -0
  3. config.json +7 -0
  4. demo.py +100 -0
  5. dfc_crosscoder.py +201 -0
  6. inference_config.json +13 -0
  7. minimal_demo.py +439 -0
  8. model.pt +3 -0
  9. requirements.txt +4 -0
  10. space_requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ library_name: pytorch
6
+ tags:
7
+ - crosscoder
8
+ - sparse-autoencoder
9
+ - interpretability
10
+ - feature-extraction
11
+ - pytorch
12
+ datasets:
13
+ - fineweb
14
+ - toolrl
15
+ metrics:
16
+ - reconstruction_loss
17
+ - sparsity
18
+ base_model:
19
+ - chengq9/ToolRL-Qwen2.5-3B
20
+ - Qwen/Qwen2.5-3B
21
+ pipeline_tag: feature-extraction
22
+ ---
23
+
24
+ # DFC CrossCoder (antebe1/dfc-crosscoder-qwen-ToolRL)
25
+
26
+ A Dedicated Feature CrossCoder (DFC) trained to extract sparse, interpretable features from the activations of two related language models:
27
+ - **Model A (ToolRL)**: chengq9/ToolRL-Qwen2.5-3B
28
+ - **Model B (Base)**: Qwen/Qwen2.5-3B
29
+
30
+ The DFC learns to identify features that are:
31
+ - **A-exclusive**: Only active for the ToolRL model
32
+ - **B-exclusive**: Only active for the base model
33
+ - **Shared**: Active for both models
34
+
35
+ ## Model Details
36
+
37
+ ### Architecture
38
+ - **Dictionary Size**: 16,384 features
39
+ - **Top-K**: 90 active features per example
40
+ - **Layer**: 13 (of transformer)
41
+ - **Activation Dimension**: 2048
42
+
43
+ ### Feature Partitions
44
+ - **A-exclusive features**: 819 (5.0%)
45
+ - **B-exclusive features**: 819 (5.0%)
46
+ - **Shared features**: 14746 (90.0%)
47
+
48
+ ### Training Details
49
+ - **Training Steps**: 20,000
50
+ - **Learning Rate**: 0.0001
51
+ - **Batch Size**: 64
52
+ - **Sparsity Coefficient**: 0
53
+ - **Exclusive Sparsity Coefficient**: 0.001
54
+
55
+
56
+
57
+ ## Usage
58
+
59
+ ```python
60
+ from transformers import AutoTokenizer
61
+ from dfc_crosscoder import DFCCrossCoder, extract_activations
62
+
63
+ # Load the model
64
+ dfc = DFCCrossCoder.from_pretrained("antebe1/dfc-crosscoder-qwen-ToolRL")
65
+
66
+ # Load base models (you need both original models)
67
+ model_a = AutoModelForCausalLM.from_pretrained("chengq9/ToolRL-Qwen2.5-3B")
68
+ model_b = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
69
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
70
+
71
+ # Extract and encode features
72
+ text = "Your input text here"
73
+ activations = extract_activations(model_a, model_b, tokenizer, [text], layer=13)
74
+ features = dfc.encode(activations)
75
+
76
+ # Analyze features
77
+ active_features = (features > 0).nonzero(as_tuple=True)[1]
78
+ print(f"Active features: {active_features.tolist()}")
79
+
80
+ # Decode back to activations
81
+ reconstructed = dfc.decode(features)
82
+ ```
83
+
84
+ ## Model Files
85
+
86
+ - `model.pt` - PyTorch model weights
87
+ - `config.json` - Model configuration
88
+ - `dfc_crosscoder.py` - Model implementation
89
+ - `demo.py` - Minimal usage demo
90
+ - `requirements.txt` - Dependencies
91
+
92
+ ## Intended Use
93
+
94
+ This model is designed for:
95
+ - **Interpretability research**: Understanding differences between fine-tuned and base models
96
+ - **Feature analysis**: Identifying model-specific vs shared computational patterns
97
+ - **Steering experiments**: Modifying model behavior through feature manipulation
98
+ - **Mechanistic interpretability**: Studying how fine-tuning affects internal representations
99
+
100
+ ## Limitations
101
+
102
+ - Trained on specific model pair (chengq9/ToolRL-Qwen2.5-3B / Qwen/Qwen2.5-3B)
103
+ - Features are extracted from layer 13 only
104
+ - Requires both original models for activation extraction
105
+ - Performance depends on quality of training data and hyperparameters
106
+
107
+ ## Citation
108
+
109
+ ```bibtex
110
+ @misc{dfc_crosscoder_antebe1_dfc_crosscoder_qwen_ToolRL,
111
+ title={DFC CrossCoder: Sparse Feature Extraction for Model Comparison},
112
+ author={Your Name Here},
113
+ year={2026},
114
+ url={https://huggingface.co/antebe1/dfc-crosscoder-qwen-ToolRL}
115
+ }
116
+ ```
117
+
118
+ ## License
119
+
120
+ MIT License - see LICENSE file for details.
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — Hugging Face Space demo for DFC CrossCoder.
3
+
4
+ This file creates a Gradio demo that can be deployed to Hugging Face Spaces.
5
+ Upload this along with the model files to create a working demo.
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ import json
13
+
14
+
15
+ # Simplified DFC for Space demo
16
+ class DFCCrossCoder(torch.nn.Module):
17
+ def __init__(self, activation_dim: int, dict_size: int, k: int, n_a: int, n_b: int):
18
+ super().__init__()
19
+ self.activation_dim = activation_dim
20
+ self.dict_size = dict_size
21
+ self.k = k
22
+ self.n_a = n_a
23
+ self.n_b = n_b
24
+ self.n_shared = dict_size - n_a - n_b
25
+ self.a_end = n_a
26
+ self.b_end = n_a + n_b
27
+
28
+ self.W_enc = torch.nn.Parameter(torch.zeros(2, activation_dim, dict_size))
29
+ self.b_enc = torch.nn.Parameter(torch.zeros(dict_size))
30
+ self.W_dec = torch.nn.Parameter(torch.zeros(dict_size, 2, activation_dim))
31
+ self.b_dec = torch.nn.Parameter(torch.zeros(2, activation_dim))
32
+
33
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
34
+ pre = torch.einsum("bmd,mdf->bf", x, self.W_enc) + self.b_enc
35
+ pre = F.relu(pre)
36
+ topk_vals, topk_idx = torch.topk(pre, self.k, dim=-1)
37
+ features = torch.zeros_like(pre)
38
+ features.scatter_(-1, topk_idx, topk_vals)
39
+ return features
40
+
41
+ def decode(self, features: torch.Tensor) -> torch.Tensor:
42
+ return torch.einsum("bf,fmd->bmd", features, self.W_dec) + self.b_dec
43
+
44
+ @classmethod
45
+ def from_pretrained(cls, model_path: str = ".", device: str = "cpu"):
46
+ # Load config
47
+ with open(f"{model_path}/config.json") as f:
48
+ config = json.load(f)
49
+
50
+ model = cls(
51
+ activation_dim=config["activation_dim"],
52
+ dict_size=config["dict_size"],
53
+ k=config["k"],
54
+ n_a=config.get("n_a", int(config["dict_size"] * 0.05)),
55
+ n_b=config.get("n_b", int(config["dict_size"] * 0.05))
56
+ )
57
+
58
+ state_dict = torch.load(f"{model_path}/model.pt", map_location=device, weights_only=True)
59
+ model.load_state_dict(state_dict)
60
+ return model.to(device)
61
+
62
+
63
+ # Global variables for models (loaded once)
64
+ dfc_model = None
65
+ model_a = None
66
+ model_b = None
67
+ tokenizer = None
68
+
69
+ def load_models():
70
+ """Load all models once at startup."""
71
+ global dfc_model, model_a, model_b, tokenizer
72
+
73
+ if dfc_model is None:
74
+ print("Loading models...")
75
+
76
+ # Load DFC
77
+ dfc_model = DFCCrossCoder.from_pretrained(".", device="cpu")
78
+ dfc_model.eval()
79
+
80
+ # Load language models with reduced precision for space
81
+ model_a = AutoModelForCausalLM.from_pretrained(
82
+ "chengq9/ToolRL-Qwen2.5-3B",
83
+ torch_dtype=torch.float16, # Use half precision
84
+ device_map="auto",
85
+ low_cpu_mem_usage=True
86
+ )
87
+
88
+ model_b = AutoModelForCausalLM.from_pretrained(
89
+ "Qwen/Qwen2.5-3B",
90
+ torch_dtype=torch.float16,
91
+ device_map="auto",
92
+ low_cpu_mem_usage=True
93
+ )
94
+
95
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
96
+ if tokenizer.pad_token is None:
97
+ tokenizer.pad_token = tokenizer.eos_token
98
+
99
+ print("Models loaded!")
100
+
101
+
102
+ def analyze_text(text: str) -> str:
103
+ """Analyze input text and return formatted results."""
104
+ if not text.strip():
105
+ return "⚠️ Please enter some text to analyze."
106
+
107
+ try:
108
+ load_models() # Ensure models are loaded
109
+
110
+ # Extract activations
111
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
112
+
113
+ with torch.no_grad():
114
+ # Get activations from both models
115
+ out_a = model_a(**inputs, output_hidden_states=True)
116
+ out_b = model_b(**inputs, output_hidden_states=True)
117
+
118
+ # Extract last token activations (layer 13)
119
+ layer_idx = 13
120
+ hidden_a = out_a.hidden_states[layer_idx + 1]
121
+ hidden_b = out_b.hidden_states[layer_idx + 1]
122
+
123
+ last_idx = inputs["attention_mask"].sum(dim=1) - 1
124
+ act_a = hidden_a[0, last_idx].cpu().float()
125
+ act_b = hidden_b[0, last_idx].cpu().float()
126
+
127
+ # Combine activations
128
+ activations = torch.stack([act_a, act_b], dim=0).unsqueeze(0) # (1, 2, d)
129
+
130
+ # Encode to features
131
+ features = dfc_model.encode(activations)
132
+ feature_vec = features[0]
133
+
134
+ # Find active features
135
+ active_indices = (feature_vec > 0).nonzero(as_tuple=True)[0]
136
+ active_values = feature_vec[active_indices]
137
+
138
+ if len(active_indices) == 0:
139
+ return "🤔 No active features found. Try a different text."
140
+
141
+ # Sort by strength
142
+ sorted_indices = torch.argsort(active_values, descending=True)
143
+ top_indices = active_indices[sorted_indices[:10]]
144
+ top_values = active_values[sorted_indices[:10]]
145
+
146
+ # Partition analysis
147
+ a_excl = sum(idx < dfc_model.a_end for idx in active_indices)
148
+ b_excl = sum(dfc_model.a_end <= idx < dfc_model.b_end for idx in active_indices)
149
+ shared = sum(idx >= dfc_model.b_end for idx in active_indices)
150
+
151
+ # Reconstruction quality
152
+ reconstructed = dfc_model.decode(features)
153
+ mse_loss = F.mse_loss(reconstructed, activations).item()
154
+
155
+ # Format results
156
+ result = f"""## 🔍 Analysis Results
157
+
158
+ **Input Text**: "{text}"
159
+
160
+ ### 📊 Feature Summary
161
+ - **Total Active Features**: {len(active_indices)}
162
+ - **Reconstruction Quality**: {mse_loss:.6f} MSE
163
+
164
+ ### 🏷️ Feature Distribution
165
+ - 🔴 **ToolRL-specific**: {a_excl} features ({a_excl/len(active_indices)*100:.1f}%)
166
+ - 🔵 **Base model-specific**: {b_excl} features ({b_excl/len(active_indices)*100:.1f}%)
167
+ - 🟢 **Shared features**: {shared} features ({shared/len(active_indices)*100:.1f}%)
168
+
169
+ ### ⭐ Top Active Features
170
+ """
171
+
172
+ for i, (idx, val) in enumerate(zip(top_indices, top_values)):
173
+ if idx < dfc_model.a_end:
174
+ partition = "🔴 ToolRL"
175
+ elif idx < dfc_model.b_end:
176
+ partition = "🔵 Base"
177
+ else:
178
+ partition = "🟢 Shared"
179
+
180
+ result += f"{i+1}. Feature {idx.item()} ({partition}) - **{val.item():.4f}**\n"
181
+
182
+ return result
183
+
184
+ except Exception as e:
185
+ return f"❌ Error during analysis: {str(e)}\n\nPlease try again with different text."
186
+
187
+
188
+ # Example texts for easy testing
189
+ example_texts = [
190
+ "To solve this problem, I need to use the calculator tool.",
191
+ "The weather is beautiful today.",
192
+ "Let me search for information about machine learning.",
193
+ "I should call the API to get the current data.",
194
+ "Python is a great programming language for data science."
195
+ ]
196
+
197
+ # Create Gradio interface
198
+ demo = gr.Interface(
199
+ fn=analyze_text,
200
+ inputs=gr.Textbox(
201
+ lines=3,
202
+ placeholder="Enter text to analyze...",
203
+ label="📝 Input Text",
204
+ info="Enter any text to see how features activate differently between ToolRL and base models"
205
+ ),
206
+ outputs=gr.Markdown(label="📊 Analysis Results"),
207
+ title="🧠 DFC CrossCoder Demo",
208
+ description="""
209
+ This demo analyzes text using a **DFC CrossCoder** to reveal how features activate differently between:
210
+ - 🔴 **ToolRL Model**: Fine-tuned for tool usage
211
+ - 🔵 **Base Model**: Original Qwen2.5-3B
212
+ - 🟢 **Shared Features**: Common to both models
213
+
214
+ The CrossCoder extracts sparse, interpretable features from the internal representations of both models.
215
+ """,
216
+ examples=[[text] for text in example_texts],
217
+ theme="soft",
218
+ allow_flagging="never"
219
+ )
220
+
221
+ if __name__ == "__main__":
222
+ # Load models at startup (for better UX)
223
+ print("🚀 Starting DFC CrossCoder demo...")
224
+ load_models()
225
+
226
+ # Launch the demo
227
+ demo.launch(
228
+ share=False, # Set to True for sharing
229
+ server_name="0.0.0.0", # For Spaces
230
+ server_port=7860 # Default Spaces port
231
+ )
config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dim": 2048,
3
+ "dict_size": 16384,
4
+ "k": 90,
5
+ "n_a": 819,
6
+ "n_b": 819
7
+ }
demo.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ demo.py — Minimal demo for DFC CrossCoder usage.
3
+ """
4
+
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import json
8
+
9
+
10
+ def extract_last_token_activations(model, tokenizer, texts, layer_idx, device="cuda:0"):
11
+ """Extract last-token activations from a model."""
12
+ model.eval()
13
+ all_acts = []
14
+
15
+ with torch.no_grad():
16
+ for text in texts:
17
+ # Tokenize
18
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
19
+ input_ids = inputs["input_ids"].to(device)
20
+ attention_mask = inputs["attention_mask"].to(device)
21
+
22
+ # Forward pass
23
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
24
+
25
+ # Get last token activation
26
+ hidden_states = outputs.hidden_states[layer_idx + 1] # +1 because [0] is embedding
27
+ last_idx = attention_mask.sum(dim=1) - 1
28
+ last_token_act = hidden_states[0, last_idx]
29
+ all_acts.append(last_token_act.cpu())
30
+
31
+ return torch.stack(all_acts)
32
+
33
+
34
+ def main():
35
+ """Demo usage of DFC CrossCoder."""
36
+
37
+ # Load the DFC model (replace with your repo name)
38
+ from dfc_crosscoder import DFCCrossCoder
39
+ dfc = DFCCrossCoder.from_pretrained("your-username/dfc-crosscoder")
40
+ dfc.eval()
41
+
42
+ # Load the original models (you need both)
43
+ print("Loading models...")
44
+ model_a = AutoModelForCausalLM.from_pretrained("chengq9/ToolRL-Qwen2.5-3B")
45
+ model_b = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B")
46
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
47
+
48
+ if tokenizer.pad_token is None:
49
+ tokenizer.pad_token = tokenizer.eos_token
50
+
51
+ # Example text
52
+ texts = [
53
+ "To solve this problem, I need to use the calculator tool.",
54
+ "The weather is beautiful today.",
55
+ "Let me search for the latest news about AI research."
56
+ ]
57
+
58
+ print(f"Analyzing {len(texts)} texts...")
59
+
60
+ for i, text in enumerate(texts):
61
+ print(f"\n--- Text {i+1}: {text} ---")
62
+
63
+ # Extract activations from both models
64
+ act_a = extract_last_token_activations(model_a, tokenizer, [text], layer_idx=13)
65
+ act_b = extract_last_token_activations(model_b, tokenizer, [text], layer_idx=13)
66
+
67
+ # Combine activations
68
+ combined_acts = torch.stack([act_a[0], act_b[0]], dim=0).unsqueeze(0) # (1, 2, d)
69
+
70
+ # Encode to features
71
+ features = dfc.encode(combined_acts)
72
+
73
+ # Analyze
74
+ active_indices = (features[0] > 0).nonzero(as_tuple=True)[0]
75
+ active_values = features[0][active_indices]
76
+
77
+ # Sort by strength
78
+ sorted_indices = torch.argsort(active_values, descending=True)
79
+ top_features = active_indices[sorted_indices[:10]]
80
+ top_values = active_values[sorted_indices[:10]]
81
+
82
+ print(f"Active features: {len(active_indices)}")
83
+ print(f"Top 10 features: {top_features.tolist()}")
84
+ print(f"Values: {[f'{v:.3f}' for v in top_values.tolist()]}")
85
+
86
+ # Partition analysis
87
+ a_excl = sum(idx < dfc.a_end for idx in top_features)
88
+ b_excl = sum(dfc.a_end <= idx < dfc.b_end for idx in top_features)
89
+ shared = sum(idx >= dfc.b_end for idx in top_features)
90
+
91
+ print(f"Feature distribution: A-exclusive={a_excl}, B-exclusive={b_excl}, Shared={shared}")
92
+
93
+ # Decode features back to activations
94
+ reconstructed = dfc.decode(features)
95
+ mse_loss = torch.nn.functional.mse_loss(reconstructed, combined_acts)
96
+ print(f"Reconstruction MSE: {mse_loss.item():.6f}")
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
dfc_crosscoder.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dfc.py — Dedicated Feature CrossCoder (DFC) model.
3
+
4
+ Feature layout in dict_size
5
+ ────────────────────────────
6
+ ┌─────────────────────┬─────────────────────┬──────────────────────────┐
7
+ │ A-exclusive (n_a) │ B-exclusive (n_b) │ Shared (n_shared) │
8
+ └─────────────────────┴─────────────────────┴──────────────────────────┘
9
+ idx: 0 ─────── a_end ──────── b_end ───────────────────── dict_size
10
+
11
+ Constraints (enforced by gradient masking + _apply_masks every step)
12
+ ──────────────────────────────────────────────────────────────────────
13
+ • Model A cannot encode/decode B-exclusive features
14
+ • Model B cannot encode/decode A-exclusive features
15
+ • Shared features are accessible to both
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ from pathlib import Path
22
+
23
+ from bitsandbytes import features
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+
29
+ class DFCCrossCoder(nn.Module):
30
+
31
+ def __init__(
32
+ self,
33
+ activation_dim: int,
34
+ dict_size: int,
35
+ k: int,
36
+ model_a_exclusive_pct: float = 0.05,
37
+ model_b_exclusive_pct: float = 0.05,
38
+ ):
39
+ super().__init__()
40
+ self.activation_dim = activation_dim
41
+ self.dict_size = dict_size
42
+ self.k = k
43
+
44
+ self.n_a = int(dict_size * model_a_exclusive_pct)
45
+ self.n_b = int(dict_size * model_b_exclusive_pct)
46
+ self.n_shared = dict_size - self.n_a - self.n_b
47
+ self.a_end = self.n_a
48
+ self.b_end = self.n_a + self.n_b
49
+
50
+ print(
51
+ f"[DFC] dict={dict_size} k={k} | "
52
+ f"A-excl={self.n_a} B-excl={self.n_b} shared={self.n_shared}"
53
+ )
54
+
55
+ # Encoder: W_enc[model, d_in, dict_size]
56
+ self.W_enc = nn.Parameter(
57
+ torch.randn(2, activation_dim, dict_size) / (activation_dim ** 0.5)
58
+ )
59
+ self.b_enc = nn.Parameter(torch.zeros(dict_size))
60
+
61
+ # Decoder: W_dec[dict_size, model, d_in]
62
+ self.W_dec = nn.Parameter(
63
+ torch.randn(dict_size, 2, activation_dim) / (dict_size ** 0.5)
64
+ )
65
+ self.b_dec = nn.Parameter(torch.zeros(2, activation_dim))
66
+
67
+ # ── Partition masks (move with .to(device)) ───────────────────
68
+ # enc_mask[model, dict_size]
69
+ enc_mask = torch.ones(2, dict_size)
70
+ enc_mask[1, : self.a_end] = 0 # B cannot encode A-excl
71
+ enc_mask[0, self.a_end : self.b_end] = 0 # A cannot encode B-excl
72
+ self.register_buffer("enc_mask", enc_mask)
73
+
74
+ # dec_mask[dict_size, model]
75
+ dec_mask = torch.ones(dict_size, 2)
76
+ dec_mask[: self.a_end, 1] = 0 # A-excl: B decoder = 0
77
+ dec_mask[self.a_end : self.b_end, 0] = 0 # B-excl: A decoder = 0
78
+ self.register_buffer("dec_mask", dec_mask)
79
+
80
+ self._apply_masks()
81
+
82
+ # ── Weight enforcement ────────────────────────────────────────────
83
+
84
+ @torch.no_grad()
85
+ def _apply_masks(self):
86
+ """Zero forbidden weights. Call after every optimiser step."""
87
+ for m in range(2):
88
+ self.W_enc.data[m] *= self.enc_mask[m].unsqueeze(0)
89
+ for m in range(2):
90
+ self.W_dec.data[:, m, :] *= self.dec_mask[:, m].unsqueeze(1)
91
+
92
+ # ── Forward ───────────────────────────────────────────────────────
93
+
94
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
95
+ """x: (B, 2, d) → features: (B, dict_size) sparse top-k."""
96
+ W = self.W_enc * self.enc_mask.unsqueeze(1) # (2, d, dict)
97
+ pre = torch.einsum("bmd,mdf->bf", x, W) + self.b_enc
98
+ pre = F.relu(pre)
99
+ topk_vals, topk_idx = torch.topk(pre, self.k, dim=-1)
100
+ features = torch.zeros_like(pre)
101
+ features.scatter_(-1, topk_idx, topk_vals)
102
+ return features
103
+
104
+ def decode(self, features: torch.Tensor) -> torch.Tensor:
105
+ """features: (B, dict_size) → (B, 2, d)."""
106
+ W = self.W_dec * self.dec_mask.unsqueeze(-1) # (dict, 2, d)
107
+ return torch.einsum("bf,fmd->bmd", features, W) + self.b_dec
108
+
109
+ def forward(self, x: torch.Tensor):
110
+ """x: (B, 2, d) → (reconstruction, features)."""
111
+ features = self.encode(x)
112
+ recon = self.decode(features)
113
+ return recon, features
114
+
115
+ def loss(
116
+ self,
117
+ x: torch.Tensor,
118
+ sparsity_coef: float = 1e-3,
119
+ exclusive_sparsity_coef: float = 1e-3 # Lower penalty for exclusive features
120
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
121
+ """MSE + weighted L1 sparsity. Returns (total, mse, l1_shared, l1_exclusive)."""
122
+ recon, features = self.forward(x)
123
+ mse = F.mse_loss(recon, x)
124
+
125
+ # Split features by partition
126
+ # fa = features[:, :self.a_end] # A-exclusive
127
+ # fb = features[:, self.a_end:self.b_end] # B-exclusive
128
+ fs = features[:, self.b_end:] # Shared
129
+
130
+ # A sees: A-exclusive + shared
131
+ fa = torch.cat([features[:, :self.a_end], features[:, self.b_end:]], dim=-1) # A-exclusive + shared
132
+ fb = torch.cat([features[:, self.a_end:self.b_end], features[:, self.b_end:]], dim=-1) # B-exclusive + shared
133
+
134
+ # Separate sparsity penalties
135
+ l1_shared = fs.abs().mean()
136
+ l1_exclusive = (fa.abs().mean() + fb.abs().mean()) / 2
137
+ total = mse + exclusive_sparsity_coef * l1_exclusive + sparsity_coef * l1_shared
138
+
139
+ return total, mse, l1_shared, l1_exclusive
140
+
141
+ # ── Diagnostics ───────────────────────────────────────────────────
142
+
143
+ @torch.no_grad()
144
+ def verify_partition_integrity(self) -> dict[str, float]:
145
+ """Max absolute value in weights that should be zero."""
146
+ enc_viol = (self.W_enc.abs() * (1 - self.enc_mask).unsqueeze(1)).max().item()
147
+ dec_viol_a = self.W_dec[: self.a_end, 1, :].abs().max().item()
148
+ dec_viol_b = self.W_dec[self.a_end : self.b_end, 0, :].abs().max().item()
149
+ return {
150
+ "enc_max_violation": enc_viol,
151
+ "dec_max_violation": max(dec_viol_a, dec_viol_b),
152
+ }
153
+
154
+ @torch.no_grad()
155
+ def feature_stats(self, features: torch.Tensor) -> dict[str, float]:
156
+ """Partition-level activation stats for a batch of features."""
157
+ fa = features[:, : self.a_end]
158
+ fb = features[:, self.a_end : self.b_end]
159
+ fs = features[:, self.b_end :]
160
+ return {
161
+ "l0_total": (features > 0).float().sum(dim=-1).mean().item(),
162
+ "l0_a_excl": (fa > 0).float().sum(dim=-1).mean().item(),
163
+ "l0_b_excl": (fb > 0).float().sum(dim=-1).mean().item(),
164
+ "l0_shared": (fs > 0).float().sum(dim=-1).mean().item(),
165
+ "mean_a_excl": fa.mean().item(),
166
+ "mean_b_excl": fb.mean().item(),
167
+ "mean_shared": fs.mean().item(),
168
+ }
169
+
170
+ # ── Save / Load ───────────────────────────────────────────────────
171
+
172
+ def save(self, path: str) -> None:
173
+ Path(path).mkdir(parents=True, exist_ok=True)
174
+ torch.save(self.state_dict(), f"{path}/model.pt")
175
+ json.dump(
176
+ dict(
177
+ activation_dim=self.activation_dim,
178
+ dict_size=self.dict_size,
179
+ k=self.k,
180
+ n_a=self.n_a,
181
+ n_b=self.n_b,
182
+ ),
183
+ open(f"{path}/config.json", "w"),
184
+ indent=2,
185
+ )
186
+ print(f"[DFC] Saved → {path}")
187
+
188
+ @classmethod
189
+ def load(cls, path: str, device: str = "cpu") -> "DFCCrossCoder":
190
+ cfg = json.load(open(f"{path}/config.json"))
191
+ model = cls(
192
+ activation_dim=cfg["activation_dim"],
193
+ dict_size=cfg["dict_size"],
194
+ k=cfg["k"],
195
+ model_a_exclusive_pct=cfg["n_a"] / cfg["dict_size"],
196
+ model_b_exclusive_pct=cfg["n_b"] / cfg["dict_size"],
197
+ )
198
+ model.load_state_dict(
199
+ torch.load(f"{path}/model.pt", map_location=device, weights_only=True)
200
+ )
201
+ return model.to(device)
inference_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "dfc_crosscoder",
3
+ "model_a_name": "chengq9/ToolRL-Qwen2.5-3B",
4
+ "model_b_name": "Qwen/Qwen2.5-3B",
5
+ "tokenizer_name": "Qwen/Qwen2.5-3B",
6
+ "layer": 13,
7
+ "activation_dim": 2048,
8
+ "dict_size": 16384,
9
+ "k": 90,
10
+ "n_a": 819,
11
+ "n_b": 819,
12
+ "n_shared": 14746
13
+ }
minimal_demo.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ minimal_demo.py — Standalone minimal demo for DFC CrossCoder.
3
+
4
+ A lightweight demonstration of the DFC CrossCoder that can run as:
5
+ 1. Command-line demo
6
+ 2. Gradio web interface
7
+ 3. Hugging Face Space
8
+
9
+ Usage:
10
+ python minimal_demo.py --text "Your input text"
11
+ python minimal_demo.py --gradio # Start web interface
12
+ python minimal_demo.py --interface # Interactive CLI
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import sys
18
+ from typing import List, Dict, Tuple, Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import AutoModelForCausalLM, AutoTokenizer
23
+
24
+
25
+ # Simplified DFC implementation for demo (copy of key parts)
26
+ class SimpleDFCCrossCoder(torch.nn.Module):
27
+ """Simplified DFC CrossCoder for demo purposes."""
28
+
29
+ def __init__(self, activation_dim: int, dict_size: int, k: int, n_a: int, n_b: int):
30
+ super().__init__()
31
+ self.activation_dim = activation_dim
32
+ self.dict_size = dict_size
33
+ self.k = k
34
+ self.n_a = n_a
35
+ self.n_b = n_b
36
+ self.n_shared = dict_size - n_a - n_b
37
+ self.a_end = n_a
38
+ self.b_end = n_a + n_b
39
+
40
+ # Model weights (will be loaded from checkpoint)
41
+ self.W_enc = torch.nn.Parameter(torch.zeros(2, activation_dim, dict_size))
42
+ self.b_enc = torch.nn.Parameter(torch.zeros(dict_size))
43
+ self.W_dec = torch.nn.Parameter(torch.zeros(dict_size, 2, activation_dim))
44
+ self.b_dec = torch.nn.Parameter(torch.zeros(2, activation_dim))
45
+
46
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
47
+ """Encode activations to sparse features."""
48
+ pre = torch.einsum("bmd,mdf->bf", x, self.W_enc) + self.b_enc
49
+ pre = F.relu(pre)
50
+ topk_vals, topk_idx = torch.topk(pre, self.k, dim=-1)
51
+ features = torch.zeros_like(pre)
52
+ features.scatter_(-1, topk_idx, topk_vals)
53
+ return features
54
+
55
+ def decode(self, features: torch.Tensor) -> torch.Tensor:
56
+ """Decode features back to activations."""
57
+ return torch.einsum("bf,fmd->bmd", features, self.W_dec) + self.b_dec
58
+
59
+ @classmethod
60
+ def from_pretrained(cls, model_path: str, device: str = "cpu"):
61
+ """Load model from checkpoint."""
62
+ # Load config
63
+ import json
64
+ with open(f"{model_path}/config.json") as f:
65
+ config = json.load(f)
66
+
67
+ # Create model
68
+ model = cls(
69
+ activation_dim=config["activation_dim"],
70
+ dict_size=config["dict_size"],
71
+ k=config["k"],
72
+ n_a=config.get("n_a", int(config["dict_size"] * 0.05)),
73
+ n_b=config.get("n_b", int(config["dict_size"] * 0.05))
74
+ )
75
+
76
+ # Load weights
77
+ state_dict = torch.load(f"{model_path}/model.pt", map_location=device, weights_only=True)
78
+ model.load_state_dict(state_dict)
79
+ return model.to(device)
80
+
81
+
82
+ class DFCDemo:
83
+ """Demo class for DFC CrossCoder functionality."""
84
+
85
+ def __init__(
86
+ self,
87
+ dfc_path: str = "./checkpoints/dfc2",
88
+ model_a_name: str = "chengq9/ToolRL-Qwen2.5-3B",
89
+ model_b_name: str = "Qwen/Qwen2.5-3B",
90
+ layer: int = 13,
91
+ device: str = "auto"
92
+ ):
93
+ self.dfc_path = dfc_path
94
+ self.model_a_name = model_a_name
95
+ self.model_b_name = model_b_name
96
+ self.layer = layer
97
+
98
+ # Auto-detect device
99
+ if device == "auto":
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+ self.device = device
102
+
103
+ # Models (loaded on first use)
104
+ self._dfc = None
105
+ self._model_a = None
106
+ self._model_b = None
107
+ self._tokenizer = None
108
+
109
+ @property
110
+ def dfc(self):
111
+ """Lazy load DFC model."""
112
+ if self._dfc is None:
113
+ print("Loading DFC CrossCoder...")
114
+ self._dfc = SimpleDFCCrossCoder.from_pretrained(self.dfc_path, device=self.device)
115
+ self._dfc.eval()
116
+ return self._dfc
117
+
118
+ @property
119
+ def models(self):
120
+ """Lazy load language models."""
121
+ if self._model_a is None:
122
+ print("Loading language models...")
123
+ print(f" Model A: {self.model_a_name}")
124
+ self._model_a = AutoModelForCausalLM.from_pretrained(
125
+ self.model_a_name,
126
+ torch_dtype=torch.float32,
127
+ device_map=None
128
+ ).to(self.device).eval()
129
+
130
+ print(f" Model B: {self.model_b_name}")
131
+ self._model_b = AutoModelForCausalLM.from_pretrained(
132
+ self.model_b_name,
133
+ torch_dtype=torch.float32,
134
+ device_map=None
135
+ ).to(self.device).eval()
136
+
137
+ print(" Tokenizer...")
138
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model_b_name)
139
+ if self._tokenizer.pad_token is None:
140
+ self._tokenizer.pad_token = self._tokenizer.eos_token
141
+ self._tokenizer.padding_side = "left"
142
+
143
+ return self._model_a, self._model_b, self._tokenizer
144
+
145
+ def extract_activations(self, texts: List[str]) -> torch.Tensor:
146
+ """Extract last-token activations from both models."""
147
+ model_a, model_b, tokenizer = self.models
148
+
149
+ # Tokenize
150
+ inputs = tokenizer(
151
+ texts,
152
+ return_tensors="pt",
153
+ padding=True,
154
+ truncation=True,
155
+ max_length=512
156
+ )
157
+ input_ids = inputs["input_ids"].to(self.device)
158
+ attention_mask = inputs["attention_mask"].to(self.device)
159
+
160
+ activations = []
161
+
162
+ with torch.no_grad():
163
+ # Model A
164
+ out_a = model_a(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
165
+ hidden_a = out_a.hidden_states[self.layer + 1]
166
+ last_idx = attention_mask.sum(dim=1) - 1
167
+ act_a = hidden_a[torch.arange(len(texts)), last_idx]
168
+
169
+ # Model B
170
+ out_b = model_b(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
171
+ hidden_b = out_b.hidden_states[self.layer + 1]
172
+ act_b = hidden_b[torch.arange(len(texts)), last_idx]
173
+
174
+ # Stack as (batch, models, hidden_dim)
175
+ activations = torch.stack([act_a, act_b], dim=1)
176
+
177
+ return activations
178
+
179
+ def analyze_text(self, text: str) -> Dict:
180
+ """Analyze a single text and return feature breakdown."""
181
+ # Extract activations
182
+ activations = self.extract_activations([text])
183
+
184
+ # Encode to features
185
+ features = self.dfc.encode(activations)
186
+ feature_vec = features[0] # Single text
187
+
188
+ # Find active features
189
+ active_indices = (feature_vec > 0).nonzero(as_tuple=True)[0]
190
+ active_values = feature_vec[active_indices]
191
+
192
+ # Sort by activation strength
193
+ sorted_indices = torch.argsort(active_values, descending=True)
194
+ top_indices = active_indices[sorted_indices[:20]] # Top 20
195
+ top_values = active_values[sorted_indices[:20]]
196
+
197
+ # Partition analysis
198
+ a_excl_count = sum(idx < self.dfc.a_end for idx in active_indices)
199
+ b_excl_count = sum(self.dfc.a_end <= idx < self.dfc.b_end for idx in active_indices)
200
+ shared_count = sum(idx >= self.dfc.b_end for idx in active_indices)
201
+
202
+ # Reconstruction quality
203
+ reconstructed = self.dfc.decode(features)
204
+ mse_loss = F.mse_loss(reconstructed, activations).item()
205
+
206
+ return {
207
+ "text": text,
208
+ "total_active_features": len(active_indices),
209
+ "top_features": [
210
+ {"index": idx.item(), "value": val.item(), "partition": self._get_partition_name(idx.item())}
211
+ for idx, val in zip(top_indices, top_values)
212
+ ],
213
+ "partition_counts": {
214
+ "A_exclusive": a_excl_count,
215
+ "B_exclusive": b_excl_count,
216
+ "Shared": shared_count
217
+ },
218
+ "reconstruction_mse": mse_loss,
219
+ "model_info": {
220
+ "dict_size": self.dfc.dict_size,
221
+ "k": self.dfc.k,
222
+ "layer": self.layer,
223
+ "model_a": self.model_a_name,
224
+ "model_b": self.model_b_name
225
+ }
226
+ }
227
+
228
+ def _get_partition_name(self, feature_idx: int) -> str:
229
+ """Get partition name for a feature index."""
230
+ if feature_idx < self.dfc.a_end:
231
+ return "A-exclusive"
232
+ elif feature_idx < self.dfc.b_end:
233
+ return "B-exclusive"
234
+ else:
235
+ return "Shared"
236
+
237
+ def compare_texts(self, texts: List[str]) -> List[Dict]:
238
+ """Compare multiple texts."""
239
+ return [self.analyze_text(text) for text in texts]
240
+
241
+
242
+ def print_analysis(analysis: Dict):
243
+ """Print analysis results in a nice format."""
244
+ print(f"\n{'='*60}")
245
+ print(f"TEXT: {analysis['text']}")
246
+ print(f"{'='*60}")
247
+
248
+ print(f"Active Features: {analysis['total_active_features']}")
249
+ print(f"Reconstruction MSE: {analysis['reconstruction_mse']:.6f}")
250
+
251
+ print(f"\nPartition Distribution:")
252
+ for partition, count in analysis['partition_counts'].items():
253
+ percentage = count / analysis['total_active_features'] * 100 if analysis['total_active_features'] > 0 else 0
254
+ print(f" {partition}: {count} ({percentage:.1f}%)")
255
+
256
+ print(f"\nTop Active Features:")
257
+ for i, feat in enumerate(analysis['top_features'][:10]):
258
+ print(f" {i+1:2d}. Feature {feat['index']:5d} | {feat['partition']:12s} | Value: {feat['value']:.4f}")
259
+
260
+
261
+ def create_gradio_interface(demo: DFCDemo):
262
+ """Create Gradio web interface."""
263
+ try:
264
+ import gradio as gr
265
+ except ImportError:
266
+ raise ImportError("Please install gradio: pip install gradio")
267
+
268
+ def analyze_interface(text):
269
+ """Gradio interface function."""
270
+ if not text.strip():
271
+ return "Please enter some text to analyze."
272
+
273
+ try:
274
+ analysis = demo.analyze_text(text.strip())
275
+
276
+ # Format results
277
+ result = f"""
278
+ ## Analysis Results
279
+
280
+ **Text**: {analysis['text']}
281
+
282
+ **Active Features**: {analysis['total_active_features']}
283
+ **Reconstruction MSE**: {analysis['reconstruction_mse']:.6f}
284
+
285
+ ### Partition Distribution
286
+ - **A-exclusive** (ToolRL): {analysis['partition_counts']['A_exclusive']} features
287
+ - **B-exclusive** (Base): {analysis['partition_counts']['B_exclusive']} features
288
+ - **Shared**: {analysis['partition_counts']['Shared']} features
289
+
290
+ ### Top Active Features
291
+ """
292
+
293
+ for i, feat in enumerate(analysis['top_features'][:10]):
294
+ result += f"{i+1}. Feature {feat['index']} ({feat['partition']}) - Value: {feat['value']:.4f}\n"
295
+
296
+ return result
297
+
298
+ except Exception as e:
299
+ return f"Error: {str(e)}"
300
+
301
+ # Create interface
302
+ iface = gr.Interface(
303
+ fn=analyze_interface,
304
+ inputs=gr.Textbox(
305
+ lines=3,
306
+ placeholder="Enter text to analyze (e.g., 'To solve this problem, I need to use the calculator tool.')",
307
+ label="Input Text"
308
+ ),
309
+ outputs=gr.Markdown(label="Analysis Results"),
310
+ title="DFC CrossCoder Demo",
311
+ description="Analyze text using the DFC CrossCoder to see which features are active and how they're distributed between ToolRL and Base models.",
312
+ examples=[
313
+ ["To solve this problem, I need to use the calculator tool."],
314
+ ["The weather is beautiful today."],
315
+ ["Let me search for information about machine learning."],
316
+ ["I should use the weather API to get current conditions."],
317
+ ["Python is a great programming language for data science."]
318
+ ]
319
+ )
320
+
321
+ return iface
322
+
323
+
324
+ def interactive_cli(demo: DFCDemo):
325
+ """Interactive command-line interface."""
326
+ print("\n" + "="*60)
327
+ print("DFC CrossCoder Interactive Demo")
328
+ print("="*60)
329
+ print("Commands:")
330
+ print(" analyze <text> - Analyze single text")
331
+ print(" compare <text1> | <text2> | <text3> - Compare multiple texts")
332
+ print(" help - Show this help")
333
+ print(" quit - Exit")
334
+ print("="*60)
335
+
336
+ while True:
337
+ try:
338
+ user_input = input("\n> ").strip()
339
+ if not user_input:
340
+ continue
341
+
342
+ if user_input.lower() in ["quit", "q", "exit"]:
343
+ print("Goodbye!")
344
+ break
345
+ elif user_input.lower() in ["help", "h"]:
346
+ print("\nCommands:")
347
+ print(" analyze <text> - Analyze single text")
348
+ print(" compare <text1> | <text2> | <text3> - Compare multiple texts")
349
+ print(" help - Show this help")
350
+ print(" quit - Exit")
351
+ elif user_input.startswith("analyze "):
352
+ text = user_input[8:].strip()
353
+ if text:
354
+ analysis = demo.analyze_text(text)
355
+ print_analysis(analysis)
356
+ else:
357
+ print("Please provide text to analyze.")
358
+ elif user_input.startswith("compare "):
359
+ texts_str = user_input[8:].strip()
360
+ texts = [t.strip() for t in texts_str.split("|") if t.strip()]
361
+ if len(texts) < 2:
362
+ print("Please provide at least 2 texts separated by |")
363
+ else:
364
+ analyses = demo.compare_texts(texts)
365
+ for analysis in analyses:
366
+ print_analysis(analysis)
367
+ else:
368
+ print("Unknown command. Type 'help' for available commands.")
369
+
370
+ except KeyboardInterrupt:
371
+ print("\nGoodbye!")
372
+ break
373
+ except Exception as e:
374
+ print(f"Error: {e}")
375
+
376
+
377
+ def main():
378
+ parser = argparse.ArgumentParser(description="DFC CrossCoder Demo")
379
+ parser.add_argument("--text", type=str, help="Text to analyze")
380
+ parser.add_argument("--checkpoint", default="./checkpoints/dfc2", help="Path to DFC checkpoint")
381
+ parser.add_argument("--gradio", action="store_true", help="Launch Gradio web interface")
382
+ parser.add_argument("--interface", action="store_true", help="Interactive CLI mode")
383
+ parser.add_argument("--device", default="auto", help="Device (cuda/cpu/auto)")
384
+ parser.add_argument("--compare", nargs="+", help="Compare multiple texts")
385
+
386
+ args = parser.parse_args()
387
+
388
+ # Create demo
389
+ demo = DFCDemo(
390
+ dfc_path=args.checkpoint,
391
+ device=args.device
392
+ )
393
+
394
+ try:
395
+ if args.gradio:
396
+ # Launch Gradio interface
397
+ iface = create_gradio_interface(demo)
398
+ iface.launch(share=True)
399
+ elif args.interface:
400
+ # Interactive CLI
401
+ interactive_cli(demo)
402
+ elif args.text:
403
+ # Single text analysis
404
+ analysis = demo.analyze_text(args.text)
405
+ print_analysis(analysis)
406
+ elif args.compare:
407
+ # Compare multiple texts
408
+ analyses = demo.compare_texts(args.compare)
409
+ for analysis in analyses:
410
+ print_analysis(analysis)
411
+ else:
412
+ # Default examples
413
+ print("DFC CrossCoder Demo - Running example analyses...")
414
+
415
+ example_texts = [
416
+ "To solve this problem, I need to use the calculator tool.",
417
+ "The weather is beautiful today.",
418
+ "Let me search for the latest research papers.",
419
+ "I should call the weather API to get current conditions."
420
+ ]
421
+
422
+ analyses = demo.compare_texts(example_texts)
423
+ for analysis in analyses:
424
+ print_analysis(analysis)
425
+
426
+ print(f"\n{'='*60}")
427
+ print("Demo completed! Try:")
428
+ print(" python minimal_demo.py --gradio # Web interface")
429
+ print(" python minimal_demo.py --interface # Interactive CLI")
430
+ print(" python minimal_demo.py --text 'Your text here'")
431
+ print("="*60)
432
+
433
+ except Exception as e:
434
+ print(f"Error: {e}")
435
+ sys.exit(1)
436
+
437
+
438
+ if __name__ == "__main__":
439
+ main()
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a343c4f59a9937d4cb01c1870729aa94e733ddc0f59555c77d52944fddb7d93
3
+ size 537217597
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch>=1.9.0
2
+ transformers>=4.20.0
3
+ numpy
4
+ tqdm
space_requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ transformers>=4.20.0
3
+ gradio>=3.0.0
4
+ numpy
5
+ tqdm
6
+ spaces