Instructions to use elephantmipt/test-sae with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use elephantmipt/test-sae with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="elephantmipt/test-sae", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("elephantmipt/test-sae", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import PreTrainedModel | |
| from typing import Optional, Dict, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.autograd as autograd | |
| from copy import deepcopy | |
| from safetensors.torch import save_file, load_file | |
| from sae.modeling.config import SAEConfig | |
| import os | |
class BaseSAE(PreTrainedModel):
    """Base class for sparse autoencoder (SAE) models.

    Owns the parameters shared by every SAE variant in this module
    (``W_enc``, ``b_enc``, ``W_dec``, ``b_dec``) plus helpers for optional
    per-row input standardisation, decoder-row renormalisation and
    dead-feature bookkeeping. Subclasses implement ``forward``.
    """

    config_class = SAEConfig
    base_model_prefix = "sae"

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.config = config
        # Deterministic initialisation. NOTE: this reseeds torch's *global*
        # RNG, which also affects any random draws made after construction.
        torch.manual_seed(42)
        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))
        self.b_enc = nn.Parameter(torch.zeros(self.config.dict_size))
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.act_size, self.config.dict_size)
            )
        )
        # The random values drawn here are overwritten just below; the draw is
        # kept so the global RNG stream (and hence any later draws) is unchanged.
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.config.dict_size, self.config.act_size)
            )
        )
        # Initialise the decoder as the transposed encoder, then scale every
        # decoder row (one per dictionary feature) to unit L2 norm.
        self.W_dec.data[:] = self.W_enc.t().data
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        # Per-feature count of consecutive batches in which the feature never
        # fired. This is a plain tensor (not a parameter or buffer): it is not
        # part of the state dict and is not moved by ``.to()``, which is why
        # ``update_inactive_features`` re-homes it onto the activations' device.
        self.num_batches_not_active = torch.zeros((self.config.dict_size,))
        self.to(self.config.get_torch_dtype(self.config.dtype))

    def preprocess_input(self, x):
        """Cast ``x`` to the SAE dtype and optionally standardise each row.

        Returns:
            ``(x, x_mean, x_std)``. When ``config.input_unit_norm`` is set,
            ``x`` is centred and divided by its std along the last dimension,
            and the statistics are returned so ``postprocess_output`` can
            (approximately) undo the transform; otherwise both are ``None``.
        """
        x = x.to(self.config.get_torch_dtype(self.config.sae_dtype))
        if not self.config.input_unit_norm:
            return x, None, None
        x_mean = x.mean(dim=-1, keepdim=True)
        x = x - x_mean
        x_std = x.std(dim=-1, keepdim=True)
        x = x / (x_std + 1e-5)  # epsilon guards against constant rows
        return x, x_mean, x_std

    def postprocess_output(self, x_reconstruct, x_mean, x_std):
        """Map a reconstruction back to the input scale (inverse of
        ``preprocess_input``'s standardisation when it was applied)."""
        if self.config.input_unit_norm:
            x_reconstruct = x_reconstruct * x_std + x_mean
        return x_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise decoder rows to unit norm and remove the component of
        their gradient parallel to each row.

        Intended to be called between ``backward()`` and the optimiser step.
        Runs under ``no_grad`` so neither the new weights nor ``W_dec.grad``
        get attached to an autograd graph. If ``W_dec.grad`` is ``None`` (no
        backward pass yet) only the weights are renormalised.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        if self.W_dec.grad is not None:
            W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
                -1, keepdim=True
            ) * W_dec_normed
            self.W_dec.grad -= W_dec_grad_proj
        self.W_dec.data = W_dec_normed

    @torch.no_grad()
    def update_inactive_features(self, acts):
        """Update the per-feature inactivity counters from one batch.

        Features whose activations sum to zero over dim 0 are incremented;
        features with a positive sum are reset to zero. Assumes ``acts`` is
        ``(batch, dict_size)`` — TODO confirm for non-2-D callers.
        """
        feature_totals = acts.sum(0)
        # The counter is not a registered buffer, so move it next to the
        # activations; otherwise the in-place update fails on GPU.
        counter = self.num_batches_not_active.to(feature_totals.device)
        counter += (feature_totals == 0).float()
        counter[feature_totals > 0] = 0
        self.num_batches_not_active = counter
class BatchTopKSAE(BaseSAE):
    """Batch top-k SAE.

    Sparsity is enforced across the whole batch: the ``top_k * x.shape[0]``
    largest encoder activations in the batch are kept and all others are
    zeroed, so an individual row may keep more or fewer than ``top_k``
    features. The encoder bias ``b_enc`` is not used. ``x.shape[0]`` is
    treated as the batch size (assumes 2-D ``(batch, act_size)`` input —
    TODO confirm).
    """

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        dense = F.relu((x - self.b_dec) @ self.W_enc)
        flat = dense.flatten()
        keep = torch.topk(flat, self.config.top_k * x.shape[0], dim=-1)
        sparse = torch.zeros_like(flat).scatter(-1, keep.indices, keep.values)
        sparse = sparse.reshape(dense.shape)
        recon = sparse @ self.W_dec + self.b_dec
        self.update_inactive_features(sparse)
        return self.get_loss_dict(x, recon, dense, sparse, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Compute the training loss and diagnostics for one batch.

        ``loss`` is the reconstruction MSE plus the auxiliary dead-feature
        loss; ``l1_loss`` is reported but not added to ``loss``.
        """
        target = x.float()
        sq_err = (x_reconstruct.float() - target).pow(2)
        l2_loss = sq_err.mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        dead_count = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        # Per-row explained variance: 1 - squared reconstruction error divided
        # by the row's squared distance from the batch mean, averaged over rows.
        row_err = sq_err.sum(-1).squeeze()
        row_spread = (target - target.mean(0)).pow(2).sum(-1).squeeze()
        return {
            "sae_out": self.postprocess_output(x_reconstruct, x_mean, x_std),
            "feature_acts": acts_topk,
            "num_dead_features": dead_count,
            "loss": l2_loss + aux_loss,
            "l1_loss": self.config.l1_coeff * l1_norm,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "aux_loss": aux_loss,
            "explained_variance": (1 - row_err / row_spread).mean(),
            "top_k": self.config.top_k,
        }

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss that lets dead features model the residual.

        Features whose inactivity counter is at least
        ``config.n_batches_to_dead`` count as dead. Each row keeps its
        ``top_k_aux`` largest dead-feature activations (fewer if fewer are
        dead). The MSE between their decoding and the main residual
        ``x - x_reconstruct`` is scaled by ``config.aux_penalty``. Returns a
        zero scalar in ``x``'s dtype when no feature is dead.
        """
        dead = self.num_batches_not_active >= self.config.n_batches_to_dead
        n_dead = dead.sum()
        if n_dead == 0:
            return torch.tensor(0, dtype=x.dtype, device=x.device)
        dead_acts = acts[:, dead]
        chosen = torch.topk(dead_acts, min(self.config.top_k_aux, n_dead), dim=-1)
        sparse_dead = torch.zeros_like(dead_acts).scatter(
            -1, chosen.indices, chosen.values
        )
        aux_recon = sparse_dead @ self.W_dec[dead]
        residual = x.float() - x_reconstruct.float()
        return self.config.aux_penalty * (
            (aux_recon.float() - residual.float()).pow(2).mean()
        )
class TopKSAE(BaseSAE):
    """Per-row top-k SAE.

    Each row keeps only its ``top_k`` largest encoder activations; the rest
    are zeroed. The encoder bias ``b_enc`` is not used. The training loss is
    reconstruction MSE + L1 penalty + auxiliary dead-feature loss.
    """

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        dense = F.relu((x - self.b_dec) @ self.W_enc)
        winners = torch.topk(dense, self.config.top_k, dim=-1)
        sparse = torch.zeros_like(dense).scatter(-1, winners.indices, winners.values)
        recon = sparse @ self.W_dec + self.b_dec
        self.update_inactive_features(sparse)
        return self.get_loss_dict(x, recon, dense, sparse, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Compute the training loss and diagnostics for one batch."""
        target = x.float()
        sq_err = (x_reconstruct.float() - target).pow(2)
        l2_loss = sq_err.mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()
        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        dead_count = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        # Per-row explained variance: 1 - squared reconstruction error divided
        # by the row's squared distance from the batch mean, averaged over rows.
        row_err = sq_err.sum(-1).squeeze()
        row_spread = (target - target.mean(0)).pow(2).sum(-1).squeeze()
        return {
            "sae_out": self.postprocess_output(x_reconstruct, x_mean, x_std),
            "feature_acts": acts_topk,
            "num_dead_features": dead_count,
            "loss": l2_loss + l1_loss + aux_loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": (1 - row_err / row_spread).mean(),
            "aux_loss": aux_loss,
        }

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss that lets dead features model the residual.

        Features whose inactivity counter is at least
        ``config.n_batches_to_dead`` count as dead. Each row keeps its
        ``top_k_aux`` largest dead-feature activations (fewer if fewer are
        dead). The MSE between their decoding and the main residual
        ``x - x_reconstruct`` is scaled by ``config.aux_penalty``. Returns a
        zero scalar in ``x``'s dtype when no feature is dead.
        """
        dead = self.num_batches_not_active >= self.config.n_batches_to_dead
        n_dead = dead.sum()
        if n_dead == 0:
            return torch.tensor(0, dtype=x.dtype, device=x.device)
        dead_acts = acts[:, dead]
        chosen = torch.topk(dead_acts, min(self.config.top_k_aux, n_dead), dim=-1)
        sparse_dead = torch.zeros_like(dead_acts).scatter(
            -1, chosen.indices, chosen.values
        )
        aux_recon = sparse_dead @ self.W_dec[dead]
        residual = x.float() - x_reconstruct.float()
        return self.config.aux_penalty * (
            (aux_recon.float() - residual.float()).pow(2).mean()
        )
class VanillaSAE(BaseSAE):
    """Plain ReLU SAE trained with reconstruction MSE plus an L1 penalty."""

    def forward(self, x):
        x, x_mean, x_std = self.preprocess_input(x)
        hidden = F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        recon = hidden @ self.W_dec + self.b_dec
        self.update_inactive_features(hidden)
        return self.get_loss_dict(x, recon, hidden, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Compute the training loss (MSE + L1) and diagnostics for one batch."""
        target = x.float()
        sq_err = (x_reconstruct.float() - target).pow(2)
        l2_loss = sq_err.mean()
        l1_norm = acts.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        dead_count = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        # Per-row explained variance: 1 - squared reconstruction error divided
        # by the row's squared distance from the batch mean, averaged over rows.
        row_err = sq_err.sum(-1).squeeze()
        row_spread = (target - target.mean(0)).pow(2).sum(-1).squeeze()
        return {
            "sae_out": self.postprocess_output(x_reconstruct, x_mean, x_std),
            "feature_acts": acts,
            "num_dead_features": dead_count,
            "loss": l2_loss + l1_loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": (acts > 0).float().sum(-1).mean(),
            "l1_norm": l1_norm,
            "explained_variance": (1 - row_err / row_spread).mean(),
        }
| import torch | |
| import torch.nn as nn | |
class RectangleFunction(autograd.Function):
    """Rectangle kernel ``1[-0.5 < x < 0.5]`` as a float tensor.

    Backward passes ``grad_output`` straight through inside the window and
    zeroes it outside.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x > -0.5) & (x < 0.5)).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
        return grad_input


class JumpReLUFunction(autograd.Function):
    """JumpReLU ``x * 1[x > exp(log_threshold)]`` with a pseudo-derivative
    for the threshold.

    Args (to ``apply``):
        x: pre-activations.
        log_threshold: per-feature log threshold (broadcast against ``x``).
        bandwidth: width of the rectangle kernel used for the threshold's
            pseudo-derivative (a Python number; not differentiated).
    """

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        # Non-tensor arguments belong on ctx, not in save_for_backward.
        ctx.bandwidth = float(bandwidth)
        threshold = torch.exp(log_threshold)
        # Cast the mask to x's dtype: a float32 mask would silently promote
        # bf16/fp16 activations to float32 and break the following matmul.
        return x * (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).to(grad_output.dtype) * grad_output
        # Pseudo-derivative w.r.t. the threshold theta:
        #   -(theta / bandwidth) * K((x - theta) / bandwidth)
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        # The differentiated input is log(theta), so apply the chain rule
        # d theta / d log(theta) = theta. Autograd sums the per-row gradients
        # down to log_threshold's shape.
        log_threshold_grad = threshold_grad * threshold
        return x_grad, log_threshold_grad, None  # None for bandwidth


class JumpReLU(nn.Module):
    """Module wrapper around ``JumpReLUFunction`` with a learnable
    per-feature ``log_threshold`` (initialised to 0, i.e. threshold 1)."""

    def __init__(self, feature_size, bandwidth, device='cpu'):
        super().__init__()
        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
        self.bandwidth = bandwidth

    def forward(self, x):
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)


class StepFunction(autograd.Function):
    """Heaviside step ``1[x > exp(log_threshold)]`` with a pseudo-derivative
    for the threshold only (``x`` receives a zero gradient)."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        ctx.bandwidth = float(bandwidth)
        threshold = torch.exp(log_threshold)
        # Returned as float32 regardless of x's dtype, so sums of the indicator
        # (e.g. an L0 count) stay exact even under low-precision training.
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        # Pseudo-derivative w.r.t. theta: -(1 / bandwidth) * K((x - theta) / bandwidth)
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        )
        # Chain rule through theta = exp(log_threshold).
        log_threshold_grad = threshold_grad * threshold
        return x_grad, log_threshold_grad, None  # None for bandwidth
class JumpReLUSAE(BaseSAE):
    """SAE with a JumpReLU activation (learned per-feature thresholds),
    trained with reconstruction MSE plus a penalty on the estimated L0."""

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        # Created after BaseSAE.__init__ has cast the model to config.dtype,
        # so log_threshold starts in torch's default dtype until the model is
        # cast/moved again.
        self.jumprelu = JumpReLU(
            feature_size=config.dict_size,
            bandwidth=config.bandwidth,
            device=getattr(config, "device", "cpu"),
        )

    def forward(self, x, use_pre_enc_bias=False):
        """Encode, apply JumpReLU, decode and return the loss dict.

        Args:
            x: input activations.
            use_pre_enc_bias: if True, subtract ``b_dec`` from the encoder
                input only. The reconstruction target stays the preprocessed
                input, because the decoder adds ``b_dec`` back.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        x_enc = x - self.b_dec if use_pre_enc_bias else x
        pre_activations = torch.relu(x_enc @ self.W_enc + self.b_enc)
        feature_magnitudes = self.jumprelu(pre_activations)
        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        # NOTE(review): unlike the other SAE variants in this module, this
        # forward never calls update_inactive_features, so num_dead_features
        # only reflects counter updates made elsewhere.
        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
        """Compute the training loss and diagnostics for one batch.

        ``l0_norm`` is the per-row count of active features estimated with
        ``StepFunction`` (differentiable w.r.t. the thresholds). For key
        parity with the other SAEs, ``l1_loss``/``l1_norm`` carry the
        L0-based penalty and estimate.
        """
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
        # NOTE(review): the step is applied to the post-JumpReLU magnitudes;
        # the JumpReLU paper applies it to pre-activations, which also gives
        # threshold gradients to features just below their threshold. Confirm.
        l0 = StepFunction.apply(
            acts, self.jumprelu.log_threshold, self.config.bandwidth
        ).sum(dim=-1).mean()
        l1_loss = self.config.l1_coeff * l0
        loss = l2_loss + l1_loss
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        # Per-row explained variance: 1 - squared reconstruction error divided
        # by the row's squared distance from the batch mean, averaged over rows.
        per_token_l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).squeeze()
        total_variance = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss / total_variance).mean()
        return {
            "sae_out": sae_out,
            "feature_acts": acts,
            "num_dead_features": num_dead_features,
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0,
            "l1_norm": l0,
            "explained_variance": explained_variance,
        }