neuron_explainer/activations/derived_scalars/write_tensors.py (82 lines of code) (raw):

import torch

from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.models.autoencoder_context import (
    AutoencoderContext,
    get_autoencoder_output_weight_by_layer_index,
)
from neuron_explainer.models.model_component_registry import (
    LayerIndex,
    NodeType,
    WeightLocationType,
)
from neuron_explainer.models.model_context import ModelContext


def _get_weight_by_layer_index(
    model_context: ModelContext,
    layer_indices: list[int] | None,
    location_type: WeightLocationType,
) -> dict[LayerIndex, torch.Tensor]:
    """Load the ``location_type`` weight for each requested layer onto the model's device.

    Args:
        model_context: Model to read weights from via ``model_context.get_weight``.
        layer_indices: Layers to load; ``None`` means every layer in
            ``range(model_context.n_layers)``.
        location_type: Which weight to load for each layer.

    Returns:
        Dict mapping each requested layer index to its weight tensor, loaded on
        ``model_context.device``.
    """
    if layer_indices is None:
        layer_indices = list(range(model_context.n_layers))
    return {
        layer_index: model_context.get_weight(
            location_type=location_type,
            layer=layer_index,
            device=model_context.device,
        )
        for layer_index in layer_indices
    }


def get_attn_write_tensor_by_layer_index(
    model_context: ModelContext,
    layer_indices: list[int] | None,
) -> dict[LayerIndex, torch.Tensor]:
    """Returns a dictionary mapping layer index to the write weight matrix for that layer.

    Args:
        model_context: Model whose ``ATTN_TO_RESIDUAL`` weights are loaded.
        layer_indices: Layers to load; ``None`` means all ``model_context.n_layers`` layers.

    Returns:
        Dict of layer index -> attention write weight, each of shape
        (n_heads, d_head, d_model), on ``model_context.device``.
    """
    return _get_weight_by_layer_index(
        model_context, layer_indices, WeightLocationType.ATTN_TO_RESIDUAL
    )


def get_mlp_write_tensor_by_layer_index(
    model_context: ModelContext, layer_indices: list[int] | None
) -> dict[LayerIndex, torch.Tensor]:
    """Returns a dictionary mapping layer index to the MLP write weight matrix for that layer.

    Args:
        model_context: Model whose ``MLP_TO_RESIDUAL`` weights are loaded.
        layer_indices: Layers to load; ``None`` means all ``model_context.n_layers`` layers.

    Returns:
        Dict of layer index -> MLP write weight, each of shape (d_ff, d_model),
        on ``model_context.device``.
    """
    return _get_weight_by_layer_index(
        model_context, layer_indices, WeightLocationType.MLP_TO_RESIDUAL
    )


def _assert_non_none(x: LayerIndex) -> int:
    """Narrow a ``LayerIndex`` (which may be ``None``) to a concrete ``int``."""
    assert x is not None
    return x


def get_autoencoder_write_tensor_by_layer_index(
    autoencoder_context: AutoencoderContext,
    model_context: ModelContext,
) -> dict[LayerIndex, torch.Tensor]:
    """Return, per autoencoder layer, the tensor through which autoencoder latents are written.

    Two cases are supported:

    * ``autoencoder_context.dst == DerivedScalarType.MLP_POST_ACT``: the autoencoder
      output weight (contracted on its second axis, which must match d_ff) is
      composed with the MLP write weight (d_ff, d_model), yielding one
      ``(a, d_model)`` tensor per layer via ``einsum("an,nd->ad")``.
    * otherwise ``dst`` must have node type ``RESIDUAL_STREAM_CHANNEL`` and the
      autoencoder output weights are returned unchanged.

    Args:
        autoencoder_context: Autoencoder whose ``dst`` and ``layer_indices`` select
            what to compute. In the MLP case every layer index must be non-None.
        model_context: Model providing the MLP write weights (MLP case only).

    Returns:
        Dict of layer index -> write tensor.

    Raises:
        AssertionError: if ``dst`` is neither ``MLP_POST_ACT`` nor a residual-stream
            channel type, or (MLP case) if any layer index is ``None``.
    """
    if autoencoder_context.dst == DerivedScalarType.MLP_POST_ACT:
        autoencoder_output_weight_by_layer_index = get_autoencoder_output_weight_by_layer_index(
            autoencoder_context
        )
        W_out_by_layer_index = get_mlp_write_tensor_by_layer_index_with_autoencoder_context(
            autoencoder_context, model_context
        )
        # Narrow once up front instead of re-asserting for every key/lookup.
        layer_indices = [
            _assert_non_none(layer_index) for layer_index in autoencoder_context.layer_indices
        ]
        return {
            # Contract the autoencoder's d_ff axis with the MLP write weight's
            # d_ff axis to map latents directly to the residual stream.
            layer_index: torch.einsum(
                "an,nd->ad",
                autoencoder_output_weight_by_layer_index[layer_index],
                W_out_by_layer_index[layer_index],
            )
            for layer_index in layer_indices
        }
    else:
        assert (
            autoencoder_context.dst.node_type == NodeType.RESIDUAL_STREAM_CHANNEL
        ), autoencoder_context.dst
        return get_autoencoder_output_weight_by_layer_index(autoencoder_context)


def get_mlp_write_tensor_by_layer_index_with_autoencoder_context(
    autoencoder_context: AutoencoderContext,
    model_context: ModelContext,
) -> dict[int, torch.Tensor]:
    """Load MLP write weights for exactly the layers covered by an autoencoder.

    Args:
        autoencoder_context: Supplies the layer indices; every one must be non-None.
        model_context: Model whose MLP write weights are loaded.

    Returns:
        Dict of (int) layer index -> MLP write weight of shape (d_ff, d_model).

    Raises:
        AssertionError: if any of ``autoencoder_context.layer_indices`` is ``None``.
    """
    # _assert_non_none both validates and narrows the element type to int,
    # so no `# type: ignore` is needed.
    layer_indices: list[int] = [
        _assert_non_none(layer_index) for layer_index in autoencoder_context.layer_indices
    ]
    write_tensor_by_layer_index = get_mlp_write_tensor_by_layer_index(
        model_context=model_context, layer_indices=layer_indices
    )
    # Rebuild with int keys so the return type is dict[int, ...] rather than
    # dict[LayerIndex, ...].
    return {layer_index: write_tensor_by_layer_index[layer_index] for layer_index in layer_indices}