neuron_explainer/activations/derived_scalars/locations.py (188 lines of code) (raw):

""" This file contains code related to specifying the locations of derived scalars, and their inputs, within the residual stream. """ from abc import ABC, abstractmethod from typing import Literal, Sequence from neuron_explainer.activations.derived_scalars.config import DstConfig from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.indexing import ActivationIndex from neuron_explainer.models.model_component_registry import ( ActivationLocationType, LayerIndex, LocationWithinLayer, NodeType, PassType, ) class LayerIndexer(ABC): """A LayerIndexer is a function that maps from a list of indices in an original ActivationsAndMetadata object and a list of indices in a reindexed ActivationsAndMetadata object. It can do things like: - replace the activations at every layer with a reference to the activations at a single layer - replace the activations at every layer with a reference to a single activation from an ActivationLocationType that doesn't use layers (e.g. residual stream post embedding) - replace the activations at every layer with a reference to the activations one layer earlier - keep the activations at every layer the same DST computation typically acts on the activations at the same layer for each of several layers in an ActivationsAndMetadata object. When the computation requires activations from multiple distinct layers to compute the result for a given layer, this class handles the remapping so that downstream code can act on each layer independently.""" @abstractmethod def __call__(self, layer_indices: list[LayerIndex]) -> Sequence[LayerIndex | Literal["Dummy"]]: # given a list of layer indices to an original ActivationsAndMetadata object, return a list of layer indices # with which to index the activations_by_layer_index of the original object in order to obtain the reindexed # activations_by_layer_index of the new object # int refers to a normal layer index # None refers to an activation with no layer index (e.g. embeddings) # "Dummy" is used when the reindexed # ActivationsAndMetadata object does not require the activation from the original object at that layer index, # for example if it's the input to a derived scalar computation that doesn't require every activation at every # layer index pass class IdentityLayerIndexer(LayerIndexer): """Sometimes computing derived scalar D at layer L requires Scalar S from layer L, and Scalar T from layer L. In this case no changes are needed to the layer indices of the activations_by_layer_index in the ActivationsAndMetadata object. This is used for such cases (it does a no-op). """ def __init__(self) -> None: pass def __call__(self, layer_indices: list[LayerIndex]) -> list[LayerIndex]: return layer_indices def __repr__(self) -> str: return "IdentityLayerIndexer()" class OffsetLayerIndexer(LayerIndexer): """Sometimes computing derived scalar D at layer L requires Scalar S from layer L, and Scalar T from layer L-1. This is used for populating an ActivationsAndMetadata object at each layer index with references to the activations at the previous layer index. """ def __init__(self, layer_index_offset: int) -> None: self.layer_index_offset = layer_index_offset def __call__(self, layer_indices: list[LayerIndex]) -> list[LayerIndex | Literal["Dummy"]]: def _dummy_if_invalid( layer_index: LayerIndex, valid_indices: set[LayerIndex] ) -> LayerIndex | Literal["Dummy"]: if layer_index in valid_indices: return layer_index else: # this value represents the fact that the layer index is not needed for this computation # callers are free to use a dummy tensor in place of the activations at this layer index, # knowing that downstream DST calculations are intended to be independent of the activation # tensor provided at this layer index return "Dummy" assert all(layer_index is not None for layer_index in layer_indices) # source_layer_indices satisfy: # target_layer_indices = layer_indices # for target_layer_index, source_layer_index in zip(target_layer_indices, source_layer_indices): # target_activations_by_layer_index[target_layer_index] = source_activations_by_layer_index[source_layer_index] # (or a dummy tensor, if the index is "Dummy") source_layer_indices = [layer_index + self.layer_index_offset for layer_index in layer_indices] # type: ignore # 'invalid' layer indices are those not in starting_layer_indices; starting_layer_indices mapped to unneeded layer indices are considered "Unneeded" return [ _dummy_if_invalid(source_layer_index, set(layer_indices)) for source_layer_index in source_layer_indices ] def __repr__(self) -> str: return f"OffsetLayerIndexer(layer_index_offset={self.layer_index_offset})" class StaticLayerIndexer(LayerIndexer, ABC): """A subset of LayerIndexers have a single layer_index associated with them. This gives those LayerIndexers a common abstract interface.""" layer_index: LayerIndex def __call__(self, layer_indices: list[LayerIndex]) -> list[LayerIndex]: # this says to use the same activation tensor at every layer index requested; each layer index # is mapped to the same constant (or None) layer index return [self.layer_index for _ in layer_indices] class ConstantLayerIndexer(StaticLayerIndexer): """Sometimes computing derived scalar D at layer L requires Scalar S from layer L, and Scalar T from layer C (independent of L). This is used for populating an ActivationsAndMetadata object with references to the same activation tensor (from layer C) at every layer index L. """ def __init__(self, constant_layer_index: int) -> None: self.layer_index = constant_layer_index def __repr__(self) -> str: return f"ConstantLayerIndexer(constant_layer_index={self.layer_index})" class NoLayersLayerIndexer(StaticLayerIndexer): """Sometimes computing derived scalar D at layer L requires Scalar S from layer L, and Scalar T which doesn't have layers. This is used for populating an ActivationsAndMetadata object with references to the same activation tensor (from a location type with no layers, i.e. at the index None) at every layer index L.""" def __init__(self) -> None: self.layer_index = None def __repr__(self) -> str: return "NoLayersLayerIndexer()" DEFAULT_LAYER_INDEXER = IdentityLayerIndexer() def precedes_final_layer( derived_scalar_location_within_layer: LocationWithinLayer | None, derived_scalar_layer_index: LayerIndex, final_residual_location_within_layer: LocationWithinLayer | None, final_residual_layer_index: LayerIndex, ) -> bool: """Returns True if the derived scalar at a given layer_index precedes the final residual stream derived scalar at a specified layer_index""" # return True if the derived scalar at a given layer_index precedes the final residual stream layer_index if derived_scalar_layer_index is None: return True # activations with no layer_index are assumed to precede # all activations with layer_index; note that according to current conventions # this is true for all residual stream activations (not true e.g. for token logits) elif final_residual_layer_index is None: assert derived_scalar_layer_index is not None return False # activations with layer_index precede activations with no layer_index elif derived_scalar_layer_index < final_residual_layer_index: return True elif derived_scalar_layer_index == final_residual_layer_index: if derived_scalar_location_within_layer is None: raise ValueError( "derived_scalar_location_within_layer must be provided in case of equal layer indices" ) if final_residual_location_within_layer is None: raise ValueError( "final_residual_location_within_layer must be provided in case of equal layer indices" ) if derived_scalar_location_within_layer < final_residual_location_within_layer: # location_within_layer inherits from int; therefore they are straightforwardly comparable return True else: return False else: assert derived_scalar_layer_index > final_residual_layer_index return False def get_location_within_layer_for_dst( dst: DerivedScalarType, dst_config: DstConfig, ) -> LocationWithinLayer | None: """Determines the location within a layer for DSTs which are not associated with an activation location type, or whose location within a layer depends on information in the DstConfig (e.g. autoencoder related DSTs). Defining new direct write related DSTs may require additional entries here.""" if dst.location_within_layer is not None: # this might be determinable from the DST alone, in which case return it right away return dst.location_within_layer else: match dst.node_type: case ( NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT ): autoencoder_context = dst_config.get_autoencoder_context(dst.node_type) if autoencoder_context is not None: return autoencoder_context.dst.location_within_layer else: return None case NodeType.RESIDUAL_STREAM_CHANNEL: match dst: case DerivedScalarType.ATTN_WRITE: return LocationWithinLayer.ATTN case DerivedScalarType.PREVIOUS_LAYER_RESID_POST_MLP: return LocationWithinLayer.END_OF_PREV_LAYER case _: return None case _: return None def get_previous_residual_dst_for_node_type( node_type: NodeType, autoencoder_dst: DerivedScalarType | None, ) -> DerivedScalarType: """This function returns the DerivedScalarType of the residual stream that precedes the node type specified. autoencoder_context is only required if node_type is NodeType.ONLINE_AUTOENCODER. """ match node_type: case NodeType.ATTENTION_HEAD: return DerivedScalarType.PREVIOUS_LAYER_RESID_POST_MLP case NodeType.MLP_NEURON: return DerivedScalarType.RESID_POST_ATTN case ( NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT ): assert autoencoder_dst is not None, node_type match autoencoder_dst.node_type: case NodeType.RESIDUAL_STREAM_CHANNEL: match autoencoder_dst: case DerivedScalarType.RESID_DELTA_ATTN: return get_previous_residual_dst_for_node_type( node_type=NodeType.ATTENTION_HEAD, autoencoder_dst=None, ) case DerivedScalarType.RESID_DELTA_MLP: return get_previous_residual_dst_for_node_type( node_type=NodeType.MLP_NEURON, autoencoder_dst=None, ) case _: raise NotImplementedError(autoencoder_dst) case _: return get_previous_residual_dst_for_node_type( node_type=autoencoder_dst.node_type, autoencoder_dst=None, ) case _: raise NotImplementedError(node_type) def get_activation_index_for_residual_dst( dst: DerivedScalarType, layer_index: int, ) -> ActivationIndex: """ This returns an ActivationIndex corresponding to a residual stream activation location at a given layer_index; handles the indexing logic in the case of PREVIOUS_LAYER_RESID_POST_MLP. The ActivationIndex returned corresponds to the entire residual stream activation tensor for the layer. """ assert dst.node_type == NodeType.RESIDUAL_STREAM_CHANNEL match dst: case DerivedScalarType.PREVIOUS_LAYER_RESID_POST_MLP: if layer_index == 0: return ActivationIndex( activation_location_type=ActivationLocationType.RESID_POST_EMBEDDING, layer_index=None, tensor_indices=(), pass_type=PassType.FORWARD, ) else: return ActivationIndex( activation_location_type=ActivationLocationType.RESID_POST_MLP, layer_index=layer_index - 1, tensor_indices=(), pass_type=PassType.FORWARD, ) case DerivedScalarType.RESID_POST_MLP: return ActivationIndex( activation_location_type=ActivationLocationType.RESID_POST_MLP, layer_index=layer_index, tensor_indices=(), pass_type=PassType.FORWARD, ) case DerivedScalarType.RESID_POST_ATTN: return ActivationIndex( activation_location_type=ActivationLocationType.RESID_POST_ATTN, layer_index=layer_index, tensor_indices=(), pass_type=PassType.FORWARD, ) case _: raise NotImplementedError(dst)