neuron_explainer/activations/derived_scalars/node_write.py (95 lines of code) (raw):

"""This file defines ScalarDerivers for a single node's residual stream write vector.""" import dataclasses import torch from neuron_explainer.activations.derived_scalars.attention import ( make_attn_weighted_value_scalar_deriver, ) from neuron_explainer.activations.derived_scalars.autoencoder import ( make_online_autoencoder_latent_scalar_deriver_factory, ) from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.indexing import ( make_python_slice_from_tensor_indices, ) from neuron_explainer.activations.derived_scalars.locations import ConstantLayerIndexer from neuron_explainer.activations.derived_scalars.mlp import get_base_mlp_scalar_deriver from neuron_explainer.activations.derived_scalars.postprocessing import ResidualWriteConverter from neuron_explainer.activations.derived_scalars.scalar_deriver import ( DerivedScalarSource, DstConfig, ScalarDeriver, ) from neuron_explainer.models.autoencoder_context import MultiAutoencoderContext from neuron_explainer.models.model_component_registry import NodeType, PassType def make_node_write_scalar_deriver( dst_config: DstConfig, ) -> ScalarDeriver: """Returns a scalar deriver for the write vector from some upstream node type (MLP, autoencoder, or attention head).""" node_index = dst_config.node_index_for_attribution assert node_index is not None assert node_index.layer_index is not None model_context = dst_config.get_model_context() autoencoder_context = dst_config.get_autoencoder_context() multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context( autoencoder_context ) residual_write_converter = ResidualWriteConverter( model_context=model_context, multi_autoencoder_context=multi_autoencoder_context, ) # though called a Postprocessor, this converter is being used as part of the computation of this DST # It knows how to generate a residual stream write vector for a single node, and skips out on generating # residual stream write vectors for the entire layer worth of nodes, which is a much bigger/unnecessary matmul. dst_config_for_attribution = dataclasses.replace( dst_config, layer_indices=[node_index.layer_index], ) match node_index.node_type: case NodeType.ATTENTION_HEAD: activation_scalar_deriver = make_attn_weighted_value_scalar_deriver( dst_config=dst_config_for_attribution, ) case NodeType.MLP_NEURON: activation_scalar_deriver = get_base_mlp_scalar_deriver( dst_config=dst_config_for_attribution, ) case ( NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT ): activation_scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory( node_index.node_type )(dst_config_for_attribution) ds_index = residual_write_converter.convert_node_index_to_ds_index(node_index) sequence_token_index = ds_index.tensor_indices[0] slices_for_ds_index = make_python_slice_from_tensor_indices(ds_index.tensor_indices) def convert_activations_to_node_write( activations: torch.Tensor, layer_index: int | None, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD assert layer_index == node_index.layer_index single_node_write = residual_write_converter.postprocess_tensor( ds_index, activations[slices_for_ds_index], ) num_sequence_tokens = activations.shape[0] single_node_write_with_zeros = torch.zeros( (num_sequence_tokens,) + single_node_write.shape, device=single_node_write.device ) single_node_write_with_zeros[sequence_token_index, :] = single_node_write return single_node_write_with_zeros return activation_scalar_deriver.apply_layerwise_transform_fn_to_output( convert_activations_to_node_write, pass_type_to_transform=PassType.FORWARD, output_dst=DerivedScalarType.SINGLE_NODE_WRITE, ) def make_node_write_scalar_source( dst_config: DstConfig, ) -> DerivedScalarSource: assert dst_config.node_index_for_attribution is not None layer_index = dst_config.node_index_for_attribution.layer_index assert layer_index is not None node_write_scalar_deriver = make_node_write_scalar_deriver(dst_config) return DerivedScalarSource( scalar_deriver=node_write_scalar_deriver, pass_type=PassType.FORWARD, layer_indexer=ConstantLayerIndexer(layer_index), )