neuron_explainer/models/autoencoder.py:

from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
    """Sparse autoencoder

    Implements:
        latents = activation(encoder(x - pre_bias) + latent_bias)
        recons = decoder(latents) + pre_bias
    """

    def __init__(
        self, n_latents: int, n_inputs: int, activation: Callable = nn.ReLU(), tied: bool = False
    ) -> None:
        """
        :param n_latents: dimension of the autoencoder latent
        :param n_inputs: dimensionality of the original data (e.g. residual stream, number of MLP hidden units)
        :param activation: activation function
        :param tied: whether to tie the encoder and decoder weights
        """
        super().__init__()

        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
        self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
        self.latent_bias = nn.Parameter(torch.zeros(n_latents))
        self.activation = activation
        if tied:
            self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
        else:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)

        self.stats_last_nonzero: torch.Tensor
        self.latents_activation_frequency: torch.Tensor
        self.latents_mean_square: torch.Tensor
        self.register_buffer("stats_last_nonzero", torch.zeros(n_latents, dtype=torch.long))
        self.register_buffer(
            "latents_activation_frequency", torch.ones(n_latents, dtype=torch.float)
        )
        self.register_buffer("latents_mean_square", torch.zeros(n_latents, dtype=torch.float))

    def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :param latent_slice: slice of latents to compute
            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
        :return: autoencoder latents before activation (shape: [batch, n_latents])
        """
        x = x - self.pre_bias
        latents_pre_act = F.linear(
            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
        )
        return latents_pre_act

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents (shape: [batch, n_latents])
        """
        return self.activation(self.encode_pre_act(x))

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """
        :param latents: autoencoder latents (shape: [batch, n_latents])
        :return: reconstructed data (shape: [batch, n_inputs])
        """
        return self.decoder(latents) + self.pre_bias

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents pre activation (shape: [batch, n_latents])
                 autoencoder latents (shape: [batch, n_latents])
                 reconstructed data (shape: [batch, n_inputs])
        """
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(latents_pre_act)
        recons = self.decode(latents)

        # set all indices of self.stats_last_nonzero where (latents != 0) to 0
        self.stats_last_nonzero *= (latents == 0).all(dim=0).long()
        self.stats_last_nonzero += 1

        return latents_pre_act, latents, recons

    @classmethod
    def from_state_dict(
        cls, state_dict: dict[str, torch.Tensor], strict: bool = True
    ) -> "Autoencoder":
        n_latents, d_model = state_dict["encoder.weight"].shape
        autoencoder = cls(n_latents, d_model)
        # Retrieve activation
        activation_class_name = state_dict.pop("activation", "ReLU")
        activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU)
        activation_state_dict = state_dict.pop("activation_state_dict", {})
        if hasattr(activation_class, "from_state_dict"):
            autoencoder.activation = activation_class.from_state_dict(
                activation_state_dict, strict=strict
            )
        else:
            autoencoder.activation = activation_class()
            if hasattr(autoencoder.activation, "load_state_dict"):
                autoencoder.activation.load_state_dict(activation_state_dict, strict=strict)

        # Load remaining state dict
        autoencoder.load_state_dict(state_dict, strict=strict)
        return autoencoder

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        sd = super().state_dict(destination, prefix, keep_vars)
        sd[prefix + "activation"] = self.activation.__class__.__name__
        if hasattr(self.activation, "state_dict"):
            sd[prefix + "activation_state_dict"] = self.activation.state_dict()
        return sd


class TiedTranspose(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.linear.bias is None
        return F.linear(x, self.linear.weight.t(), None)

    @property
    def weight(self) -> torch.Tensor:
        return self.linear.weight.t()

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias


class TopK(nn.Module):
    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        # make all other values 0
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(destination, prefix, keep_vars)
        state_dict.update(
            {prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__}
        )
        return state_dict

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "TopK":
        k = state_dict["k"]
        postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]()
        return cls(k=k, postact_fn=postact_fn)


ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "Identity": nn.Identity,
    "TopK": TopK,
}
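

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds a small
# TopK autoencoder, checks latent sparsity, round-trips the model through the
# custom state_dict() / Autoencoder.from_state_dict() format, and shows the
# tied-weight variant. The sizes (n_inputs=16, n_latents=32, k=4, batch of 8)
# are illustrative assumptions, not values used by the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # n_inputs is the dimensionality being reconstructed (e.g. a residual
    # stream); n_latents is the (typically much larger) dictionary size.
    autoencoder = Autoencoder(n_latents=32, n_inputs=16, activation=TopK(k=4))

    x = torch.randn(8, 16)  # [batch, n_inputs]
    latents_pre_act, latents, recons = autoencoder(x)
    assert latents.shape == (8, 32) and recons.shape == x.shape
    # TopK zeroes everything but the k largest pre-activations per example.
    assert int((latents != 0).sum(dim=-1).max()) <= 4

    # The overridden state_dict() stores the activation's class name and its
    # own state dict, so from_state_dict() can rebuild the TopK activation.
    restored = Autoencoder.from_state_dict(autoencoder.state_dict())
    assert isinstance(restored.activation, TopK)
    assert torch.allclose(restored.decode(latents), autoencoder.decode(latents))

    # With tied=True the decoder is a TiedTranspose view of the encoder, so
    # its weight is just the transpose of the encoder's weight matrix.
    tied = Autoencoder(n_latents=32, n_inputs=16, tied=True)
    assert torch.equal(tied.decoder.weight, tied.encoder.weight.t())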