neuron_explainer/models/autoencoder_context.py (218 lines of code) (raw):

import os from dataclasses import dataclass, field from typing import Union import torch from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.file_utils import copy_to_local_cache, file_exists from neuron_explainer.models import Autoencoder from neuron_explainer.models.model_component_registry import Dimension, LayerIndex, NodeType @dataclass(frozen=True) class AutoencoderSpec: """Parameters used in the construction of an AutoencoderConfig object. Seperate so we don't need to validate when constructed""" dst: DerivedScalarType autoencoder_path_by_layer_index: dict[LayerIndex, str] @dataclass(frozen=True) class AutoencoderConfig: """ This class specifies a set of autoencoders to load from disk, for one or more layer indices. The activation location type indicates the type of activation that the autoencoder was trained on, and that will be fed into the autoencoder. """ dst: DerivedScalarType autoencoder_path_by_layer_index: dict[LayerIndex, str] def __post_init__(self) -> None: assert len(self.autoencoder_path_by_layer_index) > 0 if len(self.autoencoder_path_by_layer_index) > 1: assert ( None not in self.autoencoder_path_by_layer_index.keys() ), "layer_indices must be [None], or a list of int layer indices" @classmethod def from_spec(cls, params: AutoencoderSpec) -> "AutoencoderConfig": return cls( dst=params.dst, autoencoder_path_by_layer_index=params.autoencoder_path_by_layer_index, ) @dataclass(frozen=True) class AutoencoderContext: autoencoder_config: AutoencoderConfig device: torch.device _cached_autoencoders_by_path: dict[str, Autoencoder] = field(default_factory=dict) omit_dead_latents: bool = False """ Omit dead latents to save memory. Only happens if self.warmup() is called. Because we omit the same number of latents from all autoencoders, we can only omit the smallest number of dead latents among all autoencoders. """ @property def num_autoencoder_directions(self) -> int: """Note that this property might change after warmup() is called, if omit_dead_latents is True.""" if len(self._cached_autoencoders_by_path) == 0: raise ValueError( "num_autoencoder_directions is not populated yet. Call warmup() first." ) else: # all autoencoders have the same number of directions, so we can just check one first_autoencoder = next(iter(self._cached_autoencoders_by_path.values())) return first_autoencoder.latent_bias.shape[0] @property def _min_n_dead_latents(self) -> int: return min( count_dead_latents(autoencoder) for autoencoder in self._cached_autoencoders_by_path.values() ) @property def dst(self) -> DerivedScalarType: return self.autoencoder_config.dst @property def layer_indices(self) -> set[LayerIndex]: return set(self.autoencoder_config.autoencoder_path_by_layer_index.keys()) def get_autoencoder(self, layer_index: LayerIndex) -> Autoencoder: autoencoder_azure_path = self.autoencoder_config.autoencoder_path_by_layer_index.get( layer_index ) if autoencoder_azure_path is None: raise ValueError(f"No autoencoder path for layer_index {layer_index}") else: if autoencoder_azure_path in self._cached_autoencoders_by_path: autoencoder = self._cached_autoencoders_by_path[autoencoder_azure_path] else: # Check if the autoencoder is cached on disk disk_cache_path = os.path.join( "/tmp", autoencoder_azure_path.replace("https://", "") ) if file_exists(disk_cache_path): print(f"Loading autoencoder from disk cache: {disk_cache_path}") else: print(f"Reading autoencoder from blob storage: {autoencoder_azure_path}") copy_to_local_cache(autoencoder_azure_path, disk_cache_path) state_dict = torch.load(disk_cache_path, map_location=self.device) # released autoencoders are saved as a dict for better compatibility assert isinstance(state_dict, dict) autoencoder = Autoencoder.from_state_dict(state_dict, strict=False).to(self.device) self._cached_autoencoders_by_path[autoencoder_azure_path] = autoencoder # freeze the autoencoder for p in autoencoder.parameters(): p.requires_grad = False return autoencoder def warmup(self) -> None: """Load all autoencoders into memory.""" for layer_index in self.layer_indices: self.get_autoencoder(layer_index) # num_autoencoder_directions is always populated after warmup n_latents = self.num_autoencoder_directions if self.omit_dead_latents: # drop the dead latents to save memory, but keep the same number of directions for all autoencoders if self._min_n_dead_latents > 0: print(f"Omitting {self._min_n_dead_latents} dead latents from all autoencoders") n_latents_to_keep = n_latents - self._min_n_dead_latents for key, autoencoder in self._cached_autoencoders_by_path.items(): self._cached_autoencoders_by_path[key] = omit_least_active_latents( autoencoder, n_latents_to_keep=n_latents_to_keep ) def get_parameterized_dimension_sizes(self) -> dict[Dimension, int]: """A dictionary specifying the size of the parameterized dimensions; for convenient use with ScalarDerivers""" return { Dimension.AUTOENCODER_LATENTS: self.num_autoencoder_directions, } @property def autoencoder_node_type(self) -> NodeType | None: return _autoencoder_node_type_by_input_dst.get(self.dst) _autoencoder_node_type_by_input_dst = { # add more mappings as needed DerivedScalarType.MLP_POST_ACT: NodeType.MLP_AUTOENCODER_LATENT, DerivedScalarType.RESID_DELTA_MLP_FROM_MLP_POST_ACT: NodeType.MLP_AUTOENCODER_LATENT, DerivedScalarType.RESID_DELTA_MLP: NodeType.MLP_AUTOENCODER_LATENT, DerivedScalarType.RESID_DELTA_ATTN: NodeType.ATTENTION_AUTOENCODER_LATENT, DerivedScalarType.ATTN_WRITE: NodeType.ATTENTION_AUTOENCODER_LATENT, } @dataclass(frozen=True) class MultiAutoencoderContext: autoencoder_context_by_node_type: dict[NodeType, AutoencoderContext] @classmethod def from_context_or_multi_context( cls, input: Union[AutoencoderContext, "MultiAutoencoderContext", None], ) -> Union["MultiAutoencoderContext", None]: if isinstance(input, AutoencoderContext): return cls.from_autoencoder_context_list([input]) elif input is None: return None else: return input @classmethod def from_autoencoder_context_list( cls, autoencoder_context_list: list[AutoencoderContext] ) -> "MultiAutoencoderContext": # check if there are duplicate node types node_types = [ _autoencoder_node_type_by_input_dst[autoencoder_context.dst] for autoencoder_context in autoencoder_context_list ] if len(node_types) != len(set(node_types)): raise ValueError(f"Cannot load two autoencoders with the same node type ({node_types})") return cls( autoencoder_context_by_node_type={ _autoencoder_node_type_by_input_dst[autoencoder_context.dst]: autoencoder_context for autoencoder_context in autoencoder_context_list } ) def get_autoencoder_context( self, node_type: NodeType | None = None ) -> AutoencoderContext | None: if node_type is None or node_type == NodeType.AUTOENCODER_LATENT: # handle default case return self.get_single_autoencoder_context() else: return self.autoencoder_context_by_node_type.get(node_type, None) @property def has_single_autoencoder_context(self) -> bool: return len(self.autoencoder_context_by_node_type) == 1 def get_single_autoencoder_context(self) -> AutoencoderContext: assert self.has_single_autoencoder_context return next(iter(self.autoencoder_context_by_node_type.values())) def get_autoencoder( self, layer_index: LayerIndex, node_type: NodeType | None = None ) -> Autoencoder: autoencoder_context = self.get_autoencoder_context(node_type) assert autoencoder_context is not None return autoencoder_context.get_autoencoder(layer_index) def warmup(self) -> None: """Load all autoencoders into memory.""" for node_type, autoencoder_context in self.autoencoder_context_by_node_type.items(): print(f"Warming up autoencoder {node_type}") autoencoder_context.warmup() def get_decoder_weight(autoencoder: Autoencoder) -> torch.Tensor: return autoencoder.decoder.weight.T # shape (n_latents, d_ff) def get_autoencoder_output_weight_by_layer_index( autoencoder_context: AutoencoderContext, ) -> dict[LayerIndex, torch.Tensor]: return { layer_index: get_decoder_weight( autoencoder_context.get_autoencoder(layer_index) ) # shape (n_latents, d_ff) for layer_index in autoencoder_context.layer_indices } ACTIVATION_FREQUENCY_THRESHOLD_FOR_DEAD_LATENTS = 1e-8 def count_dead_latents(autoencoder: Autoencoder) -> int: if hasattr(autoencoder, "latents_activation_frequency"): if torch.all(autoencoder.latents_activation_frequency == 0): raise ValueError("latents_activation_frequency is all zeros, all latents are dead.") dead_latents_mask = ( autoencoder.latents_activation_frequency < ACTIVATION_FREQUENCY_THRESHOLD_FOR_DEAD_LATENTS ) num_dead_latents = int(dead_latents_mask.sum().item()) return num_dead_latents else: return 0 def omit_least_active_latents( autoencoder: Autoencoder, n_latents_to_keep: int, # if preserve_indices=True, ignore the stored activation frequencies, and keep the first indices. # this is to preserve latent indices compared to the original autoencoder. preserve_indices: bool = True, ) -> Autoencoder: n_latents_original = int(autoencoder.latent_bias.shape[0]) if n_latents_to_keep >= n_latents_original: return autoencoder device: torch.device = autoencoder.encoder.weight.device # create the dead latent mask (True for live latents, False for dead latents) mask = torch.ones(n_latents_original, dtype=torch.bool, device=device) if preserve_indices or not hasattr(autoencoder, "latents_activation_frequency"): # drop the last latents mask[n_latents_to_keep:] = 0 else: # drop the least active latents order = torch.argsort(autoencoder.latents_activation_frequency, descending=True) mask[order[n_latents_to_keep:]] = 0 # apply the mask to a new autoencoder n_latents = int(mask.sum().item()) d_model = autoencoder.pre_bias.shape[0] new_autoencoder = Autoencoder(n_latents, d_model).to(device) new_autoencoder.encoder.weight.data = autoencoder.encoder.weight[mask, :].clone() new_autoencoder.decoder.weight.data = autoencoder.decoder.weight[:, mask].clone() new_autoencoder.latent_bias.data = autoencoder.latent_bias[mask].clone() new_autoencoder.stats_last_nonzero.data = autoencoder.stats_last_nonzero[mask].clone() if hasattr(autoencoder, "latents_mean_square"): new_autoencoder.register_buffer( "latents_mean_square", torch.zeros(n_latents, dtype=torch.float) ) new_autoencoder.latents_mean_square.data = autoencoder.latents_mean_square[mask].clone() if hasattr(autoencoder, "latents_activation_frequency"): new_autoencoder.register_buffer( "latents_activation_frequency", torch.ones(n_latents, dtype=torch.float) ) new_autoencoder.latents_activation_frequency.data = ( autoencoder.latents_activation_frequency[mask].clone() ) return new_autoencoder