SAE Grid Search: Layer 12, instruct_base
Hyperparameter sweep over expansion factors, K values, and SAE architectures (BatchTopK, TopK, JumpReLU). All SAEs are trained on Michael's instruct_base layer 12 activations from OhhMoo/sae-rl-qwen05b-strict-activations.
Key Findings
- BatchTopK beats TopK at every configuration on both MSE and dead latents
- K=64 at 8x expansion leaves 57% of latents dead, so K should be at least 128, and K=220 does better still
- Sweet spot: BatchTopK at 4x or 8x expansion with K=220, giving ~90-93% frac_rec and ~13-21% dead latents
- JumpReLU reaches near-perfect reconstruction (frac_rec of 98.7% or more at threshold 0.1), but only by firing on 10-45% of latents per token; at low thresholds it is not truly sparse
- 16x expansion does not improve on 8x at matched K (at K=220 it is worse on MSE, dead latents, and frac_rec) and is not worth the extra compute
Results
Full results, including delta loss and frac_rec, are in results_full.csv. The original MSE, NMSE, and dead-latent results are in results.csv.
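The card does not spell out how delta loss and frac_rec are computed. A common convention, assumed here rather than confirmed by the source, is: delta loss = (LM loss with the SAE reconstruction spliced in at layer 12) minus (clean LM loss), and frac_rec = 1 - delta_loss / (ablated loss - clean loss), where the ablated loss replaces the layer's output with zeros. Under that definition a negative delta loss gives frac_rec above 100%, as in the JumpReLU 4x, threshold 0.1 row. The sketch below splices via a PyTorch forward hook on a toy model; `make_splice_hook`, `delta_loss`, `frac_recovered`, and `IdentitySAE` are hypothetical names, and the exact hook point inside Qwen2.5-0.5B-Instruct is not specified by the card, so it is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_splice_hook(sae: nn.Module):
    """Forward hook that swaps a module's output for the SAE reconstruction."""
    def hook(module, inputs, output):
        # Transformer blocks often return a tuple whose first item is the hidden state.
        if isinstance(output, tuple):
            recon, _ = sae(output[0])
            return (recon,) + tuple(output[1:])
        recon, _ = sae(output)
        return recon
    return hook


@torch.no_grad()
def delta_loss(model, layer, sae, inputs, targets, loss_fn) -> float:
    """Loss with the SAE spliced in at `layer`, minus the clean loss."""
    clean = loss_fn(model(inputs), targets).item()
    handle = layer.register_forward_hook(make_splice_hook(sae))
    try:
        spliced = loss_fn(model(inputs), targets).item()
    finally:
        handle.remove()  # always restore the unmodified model
    return spliced - clean


def frac_recovered(delta: float, clean_loss: float, ablated_loss: float) -> float:
    """Assumed definition: 1 - delta / (ablated - clean); exceeds 1 when delta < 0."""
    return 1.0 - delta / (ablated_loss - clean_loss)


class IdentitySAE(nn.Module):
    """Toy stand-in: a perfect reconstruction, so delta loss should be zero."""
    def forward(self, x):
        return x, x


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 3))
x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
print(delta_loss(model, model[0], IdentitySAE(), x, y, F.cross_entropy))  # 0.0
```

Using `try`/`finally` around the spliced pass guarantees the hook is removed even if the forward pass raises, so later clean evaluations are not silently contaminated.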
BatchTopK (recommended)
| Expansion | K | Val MSE | NMSE | Dead % | Delta Loss | Frac Rec % |
|---|---|---|---|---|---|---|
| 4x | 220 | 0.0228 | 0.000479 | 12.9% | 0.247 | 92.9% |
| 8x | 220 | 0.0245 | 0.000515 | 20.6% | 0.354 | 90.1% |
| 16x | 220 | 0.0298 | 0.000627 | 33.0% | 0.406 | 88.8% |
| 4x | 128 | 0.0340 | 0.000715 | 18.8% | 0.634 | 83.5% |
| 8x | 128 | 0.0369 | 0.000776 | 31.3% | 0.788 | 80.3% |
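The Dead % column can be read as the share of latents that never activate anywhere in the validation set; that criterion, and the evaluation set it is measured on, are assumptions, since the card does not state them. A minimal sketch that computes val MSE (mean over all elements), mean L0 per token, and the dead fraction for any SAE whose forward returns (recon, z), like the BatchTopKSAE class in the SAE Architecture section. `evaluate_sae` and `HalfActiveSAE` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def evaluate_sae(sae: nn.Module, batches, d_sae: int) -> dict:
    """Val MSE, mean L0 per token, and fraction of latents that never fire."""
    ever_fired = torch.zeros(d_sae, dtype=torch.bool)
    sq_err, n_elems, active, n_tokens = 0.0, 0, 0, 0
    for x in batches:  # each x: (N, d_in)
        x_hat, z = sae(x)
        sq_err += F.mse_loss(x_hat, x, reduction="sum").item()
        n_elems += x.numel()
        fired = z > 0
        active += fired.sum().item()
        n_tokens += x.shape[0]
        ever_fired |= fired.any(dim=0)  # mark latents that fired on any token
    return {
        "val_mse": sq_err / n_elems,
        "mean_l0": active / n_tokens,
        "dead_frac": 1.0 - ever_fired.float().mean().item(),
    }


class HalfActiveSAE(nn.Module):
    """Toy stand-in: perfect reconstruction, but only 4 of its 8 latents ever fire."""
    def forward(self, x):
        z = torch.zeros(x.shape[0], 8)
        z[:, :4] = 1.0
        return x, z


print(evaluate_sae(HalfActiveSAE(), [torch.randn(5, 3)], d_sae=8))
# {'val_mse': 0.0, 'mean_l0': 4.0, 'dead_frac': 0.5}
```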
JumpReLU
| Expansion | Threshold | L0 | Val MSE | NMSE | Dead % | Delta Loss | Frac Rec % |
|---|---|---|---|---|---|---|---|
| 4x | 0.5 | 77 | 0.0435 | 0.000915 | 24.0% | 0.813 | 79.8% |
| 8x | 0.5 | 100 | 0.0418 | 0.000880 | 25.5% | 1.001 | 76.2% |
| 4x | 0.1 | 1596 | 0.0007 | 0.000014 | 8.6% | -0.014 | 100.4% |
| 8x | 0.1 | 1962 | 0.0008 | 0.000017 | 18.4% | 0.043 | 98.7% |
SAE Architecture
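```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchTopKSAE(nn.Module):
    def __init__(self, d_in, d_sae, k):
        super().__init__()
        self.k = k  # average number of active latents per token
        self.b_pre = nn.Parameter(torch.zeros(d_in))  # subtracted before encoding, added back after decoding
        self.encoder = nn.Linear(d_in, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_in, bias=True)

    def encode(self, x):
        x_centered = x - self.b_pre
        pre_acts = self.encoder(x_centered)
        # BatchTopK: keep the (n_tokens * k) largest pre-activations across the
        # whole batch, so a single token may use more or fewer than k latents
        # and its code depends on the rest of the batch.
        n_keep = int(pre_acts.numel() * self.k / pre_acts.shape[-1])
        threshold = pre_acts.reshape(-1).topk(n_keep).values.min()
        acts = pre_acts * (pre_acts >= threshold).float()
        return F.relu(acts)

    def decode(self, z):
        return self.decoder(z) + self.b_pre

    def forward(self, x):
        z = self.encode(x)
        recon = self.decode(z)
        return recon, z
```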
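The sweep also trains plain TopK SAEs, whose code is not shown here. The two differ only in the sparsity rule: TopK keeps exactly k latents per token, while BatchTopK (as in `encode` above) keeps the n_tokens × k largest pre-activations across the whole batch. A minimal sketch contrasting the two rules on raw pre-activations; `topk_acts` and `batch_topk_acts` are illustrative standalone functions, not taken from the training code.

```python
import torch
import torch.nn.functional as F


def topk_acts(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """TopK: keep the k largest pre-activations in each row (token)."""
    vals, idx = pre_acts.topk(k, dim=-1)
    out = torch.zeros_like(pre_acts).scatter(-1, idx, vals)
    return F.relu(out)


def batch_topk_acts(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """BatchTopK: keep the n_tokens * k largest pre-activations in the whole batch."""
    n_keep = pre_acts.numel() // pre_acts.shape[-1] * k
    threshold = pre_acts.reshape(-1).topk(n_keep).values.min()
    return F.relu(pre_acts * (pre_acts >= threshold).to(pre_acts.dtype))


torch.manual_seed(0)
pre = torch.randn(4, 32) + 3.0  # shift so every kept value is positive
pre[0] += 2.0                   # make token 0 "louder" than the others
k = 5
print((topk_acts(pre, k) > 0).sum(-1))        # exactly k active latents per token
print((batch_topk_acts(pre, k) > 0).sum(-1))  # uneven per token, total of 4 * k
```

Letting the budget float between tokens is what allows BatchTopK to spend more latents on information-dense tokens, which is one plausible reason it beats TopK in this sweep.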
Loading SAEs
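```python
import torch
from huggingface_hub import hf_hub_download

# Requires the BatchTopKSAE class from the SAE Architecture section.
path = hf_hub_download(
    repo_id="jakelipner/sae-grid-search-layer12",
    filename="BatchTopK_exp8x_k220.pt",
    repo_type="model",
)

# d_in = 896 (Qwen2.5-0.5B hidden size); d_sae = 8 * 896 = 7168 for the 8x SAE
sae = BatchTopKSAE(d_in=896, d_sae=7168, k=220)
sae.load_state_dict(torch.load(path, map_location="cpu"))
sae.eval()

# x: layer 12 activations, shape (N, 896), dtype float32
with torch.no_grad():
    x_hat, z = sae(x)
```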
File naming
- {Type}_exp{expansion}x_k{K}.pt for BatchTopK and TopK
- JumpReLU_exp{expansion}x_t{threshold}.pt for JumpReLU
- results.csv: MSE, NMSE, and dead latents (BatchTopK and TopK only)
- results_full.csv: all metrics, including delta loss and frac_rec, for all three architectures
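The naming scheme can be wrapped in a small helper that also derives d_sae from the expansion factor (d_sae = expansion × 896, e.g. 7168 for 8x as in the loading example). `sae_filename` and `sae_dims` are hypothetical helpers, and writing JumpReLU thresholds exactly as they appear in the table (for example `t0.1`) is an assumption, so check the repository's file list if a download fails.

```python
from typing import Optional, Tuple

D_IN = 896  # hidden size of Qwen2.5-0.5B, as in the loading example


def sae_filename(arch: str, expansion: int, k: Optional[int] = None,
                 threshold: Optional[float] = None) -> str:
    """Build a checkpoint filename following the card's naming scheme."""
    if arch in ("BatchTopK", "TopK"):
        return f"{arch}_exp{expansion}x_k{k}.pt"
    if arch == "JumpReLU":
        # Assumption: the threshold is formatted as in the table, e.g. 0.1
        return f"JumpReLU_exp{expansion}x_t{threshold}.pt"
    raise ValueError(f"unknown architecture: {arch}")


def sae_dims(expansion: int) -> Tuple[int, int]:
    """Return (d_in, d_sae) for a given expansion factor."""
    return D_IN, D_IN * expansion


print(sae_filename("BatchTopK", 8, k=220))  # BatchTopK_exp8x_k220.pt
print(sae_dims(8))                          # (896, 7168)
```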
Training details
- Base activations: OhhMoo/sae-rl-qwen05b-strict-activations, instruct_base, layer 12
- Epochs: 10, LR: 1e-4, cosine annealing, batch size 512
- Decoder weights normalized to unit norm after each step
- Best checkpoint selected by val MSE
- Delta loss computed on 50 GSM8K test prompts using Qwen2.5-0.5B-Instruct
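The recipe above can be sketched as a loop with a cosine learning-rate schedule and a decoder renormalization after every optimizer step. The optimizer (Adam), stepping the scheduler per batch rather than per epoch, and the `TinySAE` stand-in model are assumptions for illustration; best-checkpoint selection by val MSE is left out. Because `nn.Linear(d_sae, d_in)` stores its weight as (d_in, d_sae), each column is one latent's direction, so the norm is taken over dim 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def normalize_decoder_(decoder: nn.Linear) -> None:
    """Rescale every latent's decoder column to unit L2 norm, in place."""
    norms = decoder.weight.norm(dim=0, keepdim=True).clamp_min(1e-8)
    decoder.weight.div_(norms)


class TinySAE(nn.Module):
    """Hypothetical stand-in with the same (recon, z) interface as BatchTopKSAE."""
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_sae)
        self.decoder = nn.Linear(d_sae, d_in)

    def forward(self, x):
        z = F.relu(self.encoder(x))
        return self.decoder(z), z


def train(sae: nn.Module, batches, epochs: int = 10, lr: float = 1e-4) -> None:
    opt = torch.optim.Adam(sae.parameters(), lr=lr)  # assumption: Adam
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs * len(batches))
    normalize_decoder_(sae.decoder)  # start from unit-norm directions
    for _ in range(epochs):
        for x in batches:
            x_hat, _ = sae(x)
            loss = F.mse_loss(x_hat, x)
            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()
            normalize_decoder_(sae.decoder)  # project back to unit norm


torch.manual_seed(0)
sae = TinySAE(d_in=16, d_sae=64)
data = [torch.randn(512, 16) for _ in range(4)]  # batch size 512, as in the card
train(sae, data, epochs=2)
print(sae.decoder.weight.norm(dim=0)[:4])  # each column norm is ~1
```

Keeping decoder columns at unit norm stops the model from shrinking activations while inflating decoder weights, which would otherwise make latent magnitudes hard to compare.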