SAE Grid Search: Layer 12, instruct_base

Hyperparameter sweep across expansion factors, K values, and SAE architectures (BatchTopK, TopK, JumpReLU) trained on Michael's instruct_base layer 12 activations from OhhMoo/sae-rl-qwen05b-strict-activations.

Key Findings

  • BatchTopK beats TopK at every configuration on both MSE and dead latents
  • K=64 at 8x expansion leaves 57% of features dead; K should be at least 128-220
  • Sweet spot: 4x or 8x expansion, BatchTopK, K=220, giving ~90-93% frac_rec and ~13-21% dead latents
  • JumpReLU achieves near-perfect reconstruction, but only by firing on 10-45% of features, so it is not truly sparse at low thresholds
  • 16x expansion adds little over 8x at matched K and is not worth the extra compute

Results

Full results, including delta loss and frac_rec, are in results_full.csv. The original MSE, NMSE, and dead-latent results are in results.csv.

BatchTopK (recommended)

Exp   K     Val MSE   NMSE       Dead %   Delta Loss   Frac Rec %
4x    220   0.0228    0.000479   12.9%    0.247        92.9%
8x    220   0.0245    0.000515   20.6%    0.354        90.1%
16x   220   0.0298    0.000627   33.0%    0.406        88.8%
4x    128   0.0340    0.000715   18.8%    0.634        83.5%
8x    128   0.0369    0.000776   31.3%    0.788        80.3%
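
To make the Dead % column easier to compare across expansion factors, the short sketch below converts it into approximate live-latent counts for the K=220 rows. It assumes the dictionary size is expansion × 896, which matches the 8x checkpoint (d_sae=7168) in the loading example; the helper name live_latents is illustrative only.

```python
# Convert the "Dead %" column into approximate live-latent counts.
D_IN = 896  # activation width, taken from the loading example

# (expansion, dead fraction) for the K=220 BatchTopK rows in the table
rows = [(4, 0.129), (8, 0.206), (16, 0.330)]


def live_latents(expansion, dead_frac, d_in=D_IN):
    """Return (dictionary size, approximate number of latents that are not dead)."""
    d_sae = expansion * d_in  # assumption: dictionary size = expansion * d_in
    return d_sae, round(d_sae * (1.0 - dead_frac))


for expansion, dead_frac in rows:
    d_sae, live = live_latents(expansion, dead_frac)
    print(f"{expansion}x: d_sae={d_sae}, live latents ~ {live}")
```

The counts are rounded from the published percentages, so they are estimates rather than values read from results_full.csv.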

JumpReLU

Exp   Threshold   L0     Val MSE   NMSE       Dead %   Delta Loss   Frac Rec %
4x    0.5         77     0.0435    0.000915   24.0%    0.813        79.8%
8x    0.5         100    0.0418    0.000880   25.5%    1.001        76.2%
4x    0.1         1596   0.0007    0.000014   8.6%     -0.014       100.4%
8x    0.1         1962   0.0008    0.000017   18.4%    0.043        98.7%
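
The L0 column can be read as a share of the dictionary, which is what the sparsity caveat about JumpReLU refers to. The sketch below does that arithmetic, assuming L0 is the mean number of active latents per token and that the dictionary size is expansion × 896; active_fraction is a hypothetical helper name.

```python
D_IN = 896  # activation width, taken from the loading example

# (expansion, threshold, L0) from the JumpReLU table
rows = [(4, 0.5, 77), (8, 0.5, 100), (4, 0.1, 1596), (8, 0.1, 1962)]


def active_fraction(expansion, l0, d_in=D_IN):
    """Fraction of the dictionary that is active on an average token."""
    d_sae = expansion * d_in  # assumption: dictionary size = expansion * d_in
    return l0 / d_sae


for expansion, threshold, l0 in rows:
    frac = active_fraction(expansion, l0)
    print(f"{expansion}x, threshold {threshold}: L0={l0} -> {frac:.1%} of latents active")
```

For comparison, under the same assumption BatchTopK with K=220 at 8x activates 220 / 7168, or roughly 3.1% of latents per token.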

SAE Architecture

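import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchTopKSAE(nn.Module):
    def __init__(self, d_in, d_sae, k):
        super().__init__()
        self.k       = k  # target mean number of active latents per sample
        self.b_pre   = nn.Parameter(torch.zeros(d_in))  # subtracted before encoding, added back after decoding
        self.encoder = nn.Linear(d_in, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_in, bias=True)

    def encode(self, x):
        x_centered = x - self.b_pre
        pre_acts   = self.encoder(x_centered)
        # Keep the (num_samples * k) largest pre-activations across the whole batch,
        # so an individual sample may use more or fewer than k latents.
        n_keep     = int(pre_acts.numel() * self.k / pre_acts.shape[-1])
        threshold  = pre_acts.reshape(-1).topk(n_keep).values.min()
        acts       = pre_acts * (pre_acts >= threshold).float()
        # ReLU zeroes any negative values that survived the batch-level cut.
        return F.relu(acts)

    def decode(self, z):
        return self.decoder(z) + self.b_pre

    def forward(self, x):
        z     = self.encode(x)
        recon = self.decode(z)
        return recon, z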
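
Because the threshold in encode is computed over the whole batch, K is a per-batch average rather than a per-sample limit. The following standalone sketch reproduces only the selection step on random pre-activations to make that visible; batch_topk_mask and the toy shapes are illustrative and not part of the released code.

```python
import torch


def batch_topk_mask(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Boolean mask keeping the (batch_size * k) largest entries of a 2-D batch."""
    n_keep = pre_acts.shape[0] * k
    threshold = pre_acts.reshape(-1).topk(n_keep).values.min()
    return pre_acts >= threshold


torch.manual_seed(0)
pre_acts = torch.randn(32, 256)  # hypothetical batch: 32 samples, 256 latents
mask = batch_topk_mask(pre_acts, k=16)

per_sample_l0 = mask.sum(dim=-1)  # number of kept latents for each sample
print("mean L0:", per_sample_l0.float().mean().item())
print("min/max L0:", per_sample_l0.min().item(), per_sample_l0.max().item())
```

The mean stays at K while the per-sample counts spread around it. A consequence is that a sample's encoding depends on which other samples share its batch, which is worth keeping in mind when encoding single examples at inference time.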

Loading SAEs

from huggingface_hub import hf_hub_download
import torch

path = hf_hub_download(
    repo_id="jakelipner/sae-grid-search-layer12",
    filename="BatchTopK_exp8x_k220.pt",
    repo_type="model"
)

sae = BatchTopKSAE(d_in=896, d_sae=7168, k=220)
sae.load_state_dict(torch.load(path, map_location="cpu"))
sae.eval()

x_hat, z = sae(x)  # x: (N, 896) float32
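
Given x, x_hat, and z, the summary metrics in the results tables can be approximated along the following lines. This is a sketch under stated assumptions: this card does not define NMSE or the dead-latent criterion, so the code uses MSE divided by the variance of the inputs for NMSE, and counts a latent as dead if it never fires on the evaluated batch. The function name sae_metrics and the stand-in tensors are hypothetical.

```python
import torch


def sae_metrics(x: torch.Tensor, x_hat: torch.Tensor, z: torch.Tensor) -> dict:
    """Reconstruction and sparsity summaries for one evaluation batch."""
    mse = (x_hat - x).pow(2).mean()
    # Assumed normalization: variance of x around its per-dimension mean.
    variance = (x - x.mean(dim=0, keepdim=True)).pow(2).mean()
    fired = z > 0
    return {
        "mse": mse.item(),
        "nmse": (mse / variance).item(),
        "l0": fired.sum(dim=-1).float().mean().item(),          # mean active latents per sample
        "dead_frac": (~fired.any(dim=0)).float().mean().item(),  # latents that never fire
    }


torch.manual_seed(0)
x = torch.randn(64, 8)                # stand-in activations, not real data
x_hat = x + 0.1 * torch.randn(64, 8)  # stand-in reconstruction
z = torch.zeros(64, 4)
z[:, :2] = 1.0                        # latents 0 and 1 always fire; 2 and 3 never do
metrics = sae_metrics(x, x_hat, z)
print(metrics)
```

With real activations, the dead fraction should be accumulated over the full validation set rather than a single batch, since a latent that is silent on one batch may fire on another.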

File naming

{Type}_exp{expansion}x_k{K}.pt for BatchTopK and TopK. JumpReLU_exp{expansion}x_t{threshold}.pt for JumpReLU.

  • results.csv: MSE, NMSE, dead latents (BatchTopK and TopK only)
  • results_full.csv: all metrics including delta loss and frac_rec for all three architectures
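
The naming pattern for BatchTopK and TopK checkpoints can be turned into a small helper, shown here as a sketch. The function names are hypothetical, and the dictionary size assumes expansion × 896, which matches the 8x example (d_sae=7168). JumpReLU is left out because the exact formatting of the threshold in the filename is not specified here.

```python
def topk_checkpoint_name(arch: str, expansion: int, k: int) -> str:
    """Build a checkpoint filename following {Type}_exp{expansion}x_k{K}.pt."""
    if arch not in ("BatchTopK", "TopK"):
        raise ValueError(f"pattern only covers BatchTopK and TopK, got {arch!r}")
    return f"{arch}_exp{expansion}x_k{k}.pt"


def dictionary_size(expansion: int, d_in: int = 896) -> int:
    """Assumed dictionary size for a given expansion factor."""
    return expansion * d_in


print(topk_checkpoint_name("BatchTopK", 8, 220))  # BatchTopK_exp8x_k220.pt
print(dictionary_size(8))                         # 7168
```

The resulting name can be passed as filename to hf_hub_download, and the size as d_sae when constructing the model, provided that configuration was actually uploaded to the repository.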

Training details

  • Base activations: OhhMoo/sae-rl-qwen05b-strict-activations instruct_base layer 12
  • Epochs: 10, LR: 1e-4, cosine annealing, batch size 512
  • Decoder weights normalized to unit norm after each step
  • Best checkpoint selected by val MSE
  • Delta loss computed on 50 GSM8K test prompts using Qwen2.5-0.5B-Instruct
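
The schedule and decoder normalization listed above could be wired together roughly as follows. This is a minimal sketch, not the training script: the optimizer (Adam), the total step count, the toy layer sizes, and the choice to normalize each latent's decoder column are all assumptions, since only the learning rate, cosine annealing, batch size, and "unit norm after each step" are stated.

```python
import torch
import torch.nn as nn

d_in, d_sae = 16, 64  # toy sizes for illustration; the real SAEs use d_in=896
decoder = nn.Linear(d_sae, d_in, bias=True)

# Assumption: Adam. The card states only LR 1e-4 with cosine annealing.
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
total_steps = 1000  # hypothetical; in practice epochs * batches per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)


@torch.no_grad()
def normalize_decoder_(layer: nn.Linear) -> None:
    """Rescale each latent's decoder direction (a weight column) to unit L2 norm."""
    norms = layer.weight.norm(dim=0, keepdim=True).clamp_min(1e-8)
    layer.weight.div_(norms)


# One illustrative optimization step on dummy data with batch size 512.
torch.manual_seed(0)
z = torch.rand(512, d_sae)
target = torch.randn(512, d_in)
loss = (decoder(z) - target).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
normalize_decoder_(decoder)  # renormalize after every optimizer step
scheduler.step()             # advance the cosine schedule

col_norms = decoder.weight.norm(dim=0)
current_lr = scheduler.get_last_lr()[0]
```

The weight of nn.Linear(d_sae, d_in) has shape (d_in, d_sae), so normalizing along dim=0 gives every latent a unit-length decoder vector.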