neuron_explainer/activations/derived_scalars/scalar_deriver.py (357 lines of code) (raw):

"""This file contains the primary code for the ScalarDeriver class.""" import dataclasses from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Callable import torch from neuron_explainer.activations.derived_scalars.activations_and_metadata import ( ActivationsAndMetadata, RawActivationStore, ) from neuron_explainer.activations.derived_scalars.config import DstConfig from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.locations import ( DEFAULT_LAYER_INDEXER, LayerIndexer, NoLayersLayerIndexer, StaticLayerIndexer, get_location_within_layer_for_dst, ) from neuron_explainer.models.model_component_registry import ( ActivationLocationType, ActivationLocationTypeAndPassType, Dimension, LayerIndex, LocationWithinLayer, PassType, ) ### SHARED CODE FOR DERIVING SCALARS FROM ACTIVATIONS ### @dataclass(frozen=True) class DerivedScalarTypeAndPassType: dst: DerivedScalarType pass_type: PassType class ScalarSource(ABC): pass_type: PassType layer_indexer: LayerIndexer @property @abstractmethod def exists_by_default(self) -> bool: # returns True if the activation is instantiated by default in a normal transformer forward pass # this is False for activations related to autoencoders or for non-trivial derived scalars pass @property @abstractmethod def dst(self) -> DerivedScalarType: pass @property def dst_and_pass_type(self) -> "DerivedScalarTypeAndPassType": return DerivedScalarTypeAndPassType( self.dst, self.pass_type, ) @property @abstractmethod def sub_activation_location_type_and_pass_types( self, ) -> tuple[ActivationLocationTypeAndPassType, ...]: pass @property @abstractmethod def location_within_layer(self) -> LocationWithinLayer | None: pass @property def layer_index(self) -> LayerIndex: """Convenience method to get the single layer index associated with this ScalarSource, if such a single layer index exists. Throws an error if it does not.""" assert isinstance(self.layer_indexer, StaticLayerIndexer), ( self.layer_indexer, "ScalarSource.layer_index should only be called for ScalarSource StaticLayerIndexer", ) return self.layer_indexer.layer_index @abstractmethod def derive_from_raw( self, raw_activation_store: RawActivationStore, desired_layer_indices: ( list[LayerIndex] | None ), # indicates layer indices to keep; None indicates keep all ) -> ActivationsAndMetadata: """Given raw activations, derive the scalar value. desired_layer_indices is a list of layer indices to include in the output; None indicates all layers, while [None] indicates activations not indexed by layers (e.g. from the embedding).""" pass # note that this class, inheriting from ActivationLocationTypeAndPassType, becomes a # base class. This needs to be a separate object from ActivationLocationTypeAndPassType, # and located within this file, because ScalarSource needs to know about DerivedScalarTypes, # which are defined within the derived_scalars/ directory class RawScalarSource(ActivationLocationTypeAndPassType, ScalarSource): def __init__( self, activation_location_type: ActivationLocationType, pass_type: PassType, layer_indexer: LayerIndexer = DEFAULT_LAYER_INDEXER, ) -> None: super().__init__(activation_location_type, pass_type) self.layer_indexer = layer_indexer if activation_location_type.has_no_layers: assert isinstance(layer_indexer, NoLayersLayerIndexer), self @property def dst(self) -> DerivedScalarType: return DerivedScalarType.from_activation_location_type(self.activation_location_type) @property def sub_activation_location_type_and_pass_types( self, ) -> tuple[ActivationLocationTypeAndPassType, ...]: return (self.activation_location_type_and_pass_type,) @property def exists_by_default(self) -> bool: return self.activation_location_type.exists_by_default @property def location_within_layer(self) -> LocationWithinLayer | None: return self.activation_location_type.location_within_layer @property def activation_location_type_and_pass_type(self) -> ActivationLocationTypeAndPassType: return ActivationLocationTypeAndPassType(self.activation_location_type, self.pass_type) def derive_from_raw( self, raw_activation_store: RawActivationStore, desired_layer_indices: ( list[LayerIndex] | None ), # indicates layer indices to keep; None indicates keep all ) -> ActivationsAndMetadata: return raw_activation_store.get_activations_and_metadata( self.activation_location_type, self.pass_type, ).apply_layer_indexer(self.layer_indexer, desired_layer_indices) class DerivedScalarSource(ScalarSource): scalar_deriver: "ScalarDeriver" def __init__( self, scalar_deriver: "ScalarDeriver", pass_type: PassType, layer_indexer: LayerIndexer = DEFAULT_LAYER_INDEXER, ) -> None: self.scalar_deriver = scalar_deriver self.pass_type = pass_type self.layer_indexer = layer_indexer @property def exists_by_default(self) -> bool: return False @property def dst(self) -> DerivedScalarType: return self.scalar_deriver.dst @property def sub_activation_location_type_and_pass_types( self, ) -> tuple[ActivationLocationTypeAndPassType, ...]: return self.scalar_deriver.get_sub_activation_location_type_and_pass_types() @property def location_within_layer(self) -> LocationWithinLayer | None: return self.scalar_deriver.location_within_layer def derive_from_raw( self, raw_activation_store: RawActivationStore, desired_layer_indices: ( list[LayerIndex] | None ), # indicates layer indices to keep; None indicates keep all ) -> ActivationsAndMetadata: return self.scalar_deriver.derive_from_raw( raw_activation_store, self.pass_type ).apply_layer_indexer(self.layer_indexer, desired_layer_indices=desired_layer_indices) @dataclass(frozen=True) class ScalarDeriver: """Contains the information necessary for specifying some function of one or more activations, (this function can be as simple as the identity function). This includes: what activations are required to compute it; a function that takes in ActivationsAndMetadata for each of those activations and returns a ActivationsAndMetadata for the derived scalar; and a function that returns the shape you expect the derived scalar to have for each token (e.g. one float per attention head, one float per layer, etc.). The function for computing this derived scalar on the forward pass can be different from the function for computing its gradient on the backward pass, so the pass type must also be an argument to the function that computes the scalar. A HookLocationType describes the type of activation that is saved during inference, and a ScalarDeriver describes the type of "derived" scalar computed from those activations after they are read from disk. In the simplest case, a derived scalar can be computed directly from the saved activations with an identity transformation (e.g. a single MLP activation is saved during inference, and the derived scalar is the same MLP activation).""" dst: DerivedScalarType """ Dataclass with fields needed to construct a ScalarDeriver for this DerivedScalarType; e.g. derived scalars computed using model weights will require at minimum the model_name to load the weights. """ dst_config: DstConfig """ Contains ActivationLocationTypes or other ScalarDerivers, and corresponding pass directions (forward or backward) that are required to compute this derived scalar type. These are loaded from disk and passed to the tensor_calculate_derived_scalar_fn as a single tuple argument. """ sub_scalar_sources: tuple[ScalarSource, ...] """ A function that takes a tuple of tensors, a layer index, and a pass type, and returns a tensor containing the derived scalar values. layer_index can be None in case of activation location types that don't have layer indices, like embeddings. """ tensor_calculate_derived_scalar_fn: Callable[ [tuple[torch.Tensor, ...], LayerIndex, PassType], torch.Tensor ] """In cases where a ScalarDeriver is a transform applied to another scalar deriver, the location within a layer associated with the resulting scalar deriver is taken to be the same as the location within a layer associated with the original scalar deriver. See definition of LocationWithinLayer in model_component_registry.py for more details.""" _specified_location_within_layer: LocationWithinLayer | None = None @property def device_for_raw_activations(self) -> torch.device: """Which device to read raw activations onto.""" return self.dst_config.get_device() @property def shape_of_activation_per_token_spec(self) -> tuple[Dimension, ...]: # first dimension is num_sequence_tokens; this can be either the literal number of tokens in a sequence or # the number of token pairs in a sequence return self.dst.shape_spec_per_token_sequence[1:] @property def location_within_layer(self) -> LocationWithinLayer | None: """An activation location type at a topologically equivalent point in the network, in terms of which residual stream locations precede and follow it.""" specified_location_within_layer = self._specified_location_within_layer dst_location_within_layer = get_location_within_layer_for_dst(self.dst, self.dst_config) if specified_location_within_layer is not None and dst_location_within_layer is not None: assert specified_location_within_layer == dst_location_within_layer consensus_location_within_layer = ( specified_location_within_layer or dst_location_within_layer ) return consensus_location_within_layer def _check_dst_and_pass_types( self, activation_data_tuple: tuple[ActivationsAndMetadata, ...] ) -> None: """Check that the derived scalar types and pass types of the raw activations match the order of the dsts and pass types in self.get_sub_dst_and_pass_types(). """ assert len(activation_data_tuple) == len(self.get_sub_dst_and_pass_types()), ( [activation_data.dst for activation_data in activation_data_tuple], [ sub_dst_and_pass_type.dst for sub_dst_and_pass_type in self.get_sub_dst_and_pass_types() ], ) for activation_data, sub_dst_and_pass_type in zip( activation_data_tuple, self.get_sub_dst_and_pass_types() ): assert ( activation_data.dst == sub_dst_and_pass_type.dst ), f"{activation_data.dst=}, {sub_dst_and_pass_type.dst=}" assert activation_data.pass_type == sub_dst_and_pass_type.pass_type, ( f"{self.dst=}, " f"{activation_data.dst=}, " f"{activation_data.pass_type=}, {sub_dst_and_pass_type.pass_type=}" ) assert activation_data.pass_type == sub_dst_and_pass_type.pass_type return def activations_and_metadata_calculate_derived_scalar_fn( self, activation_data_tuple: tuple[ActivationsAndMetadata, ...], pass_type: PassType ) -> ActivationsAndMetadata: self._check_dst_and_pass_types(activation_data_tuple) for activation_data in activation_data_tuple: assert len(activation_data.activations_by_layer_index) > 0, ( f"{activation_data.activations_by_layer_index=}" f"{activation_data.dst=}" f"{activation_data.pass_type=}" ) activation_data = activation_data_tuple[0] filtered_activation_data = activation_data.filter_layers( layer_indices=self.dst_config.layer_indices ) if len(activation_data_tuple) == 1: def _calculate_derived_scalar_fn( activations: torch.Tensor, layer_index: LayerIndex, ) -> torch.Tensor: return self.tensor_calculate_derived_scalar_fn( (activations,), layer_index, pass_type ) return filtered_activation_data.apply_layerwise_transform_fn_to_activations( layerwise_transform_fn=_calculate_derived_scalar_fn, output_dst=self.dst, output_pass_type=pass_type, ) elif len(activation_data_tuple) >= 2: def _calculate_multi_arg_derived_scalar_fn( *args: torch.Tensor, layer_index: LayerIndex, ) -> torch.Tensor: return self.tensor_calculate_derived_scalar_fn(tuple(args), layer_index, pass_type) other_filtered_activation_data_tuple = tuple( activation_data.filter_layers(layer_indices=self.dst_config.layer_indices) for activation_data in activation_data_tuple[1:] ) return filtered_activation_data.apply_layerwise_transform_fn_to_multiple_activations( # care should be taken in a dictionary comprehension of callables that the # variables (i.e. layer_index) are bound at time of creation, not at time of execution # partial accomplishes this layerwise_transform_fn=_calculate_multi_arg_derived_scalar_fn, others=other_filtered_activation_data_tuple, output_dst=self.dst, output_pass_type=pass_type, ) else: raise NotImplementedError( f"ScalarDeriver.activations_and_metadata_calculate_derived_scalar_fn not implemented for " f"{len(activation_data_tuple)=}" ) def derive_from_raw( self, raw_activation_store: RawActivationStore, pass_type: PassType, ) -> ActivationsAndMetadata: desired_layer_indices = None sub_activations_list = [] for sub_scalar_source in self.get_sub_scalar_sources(): sub_activation_data = sub_scalar_source.derive_from_raw( raw_activation_store, desired_layer_indices=desired_layer_indices ) sub_activations_list.append(sub_activation_data) if len(sub_activations_list) == 1: desired_layer_indices = list(sub_activations_list[0].layer_indices) return self.activations_and_metadata_calculate_derived_scalar_fn( tuple(sub_activations_list), pass_type ) def to_serializable_dict(self) -> dict[str, Any]: return { "dst": self.dst, "dst_config": self.dst_config, } def get_sub_dst_and_pass_types(self) -> tuple[DerivedScalarTypeAndPassType, ...]: return tuple( sub_scalar_source.dst_and_pass_type for sub_scalar_source in self.sub_scalar_sources ) def get_sub_scalar_sources(self) -> tuple[ScalarSource, ...]: return self.sub_scalar_sources def get_sub_activation_location_type_and_pass_types( self, ) -> tuple[ActivationLocationTypeAndPassType, ...]: sub_activation_location_type_and_pass_types_list = [] for scalar_source in self.get_sub_scalar_sources(): sub_activation_location_type_and_pass_types_list.extend( list(scalar_source.sub_activation_location_type_and_pass_types) ) return tuple(sub_activation_location_type_and_pass_types_list) @property def n_input_tensors(self) -> int: # the number of arguments expected by the top-level function. Note that this is not necessarily # the same as the number of sub_activation_location_type_and_pass_types; some of these might be # consumed by lower-level ScalarDerivers, and combined into single tensors passed to the top-level # function. return len(self.get_sub_dst_and_pass_types()) @property def n_total_required_tensors(self) -> int: # the number of tensors required to compute the derived scalar, including those that are # passed to lower-level ScalarDerivers return len(self.get_sub_activation_location_type_and_pass_types()) @property def derivable_pass_types(self) -> tuple[PassType, ...]: # ScalarDerivers are configurable to support either only computing a # scalar on the forward pass, or computing it on both the forward and # the backward pass. Supporting the backward pass requires more kinds # of raw activations in general. if self.dst_config.derive_gradients: return (PassType.FORWARD, PassType.BACKWARD) else: return (PassType.FORWARD,) def apply_transform_fn_to_output( self, transform_fn: Callable[[torch.Tensor], torch.Tensor], pass_type_to_transform: PassType, output_dst: DerivedScalarType, ) -> "ScalarDeriver": """Converts one ScalarDeriver to another, by applying a tensor -> tensor function to the output. The tensor -> tensor function takes a tensor, a layer index, and a pass type, and returns a tensor, so that it can depend on layer and pass type.""" def layerwise_transform_fn( tensor: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: return transform_fn(tensor) return self.apply_layerwise_transform_fn_to_output( layerwise_transform_fn=layerwise_transform_fn, pass_type_to_transform=pass_type_to_transform, output_dst=output_dst, ) def apply_layerwise_transform_fn_to_output( self, layerwise_transform_fn: Callable[[torch.Tensor, LayerIndex, PassType], torch.Tensor], pass_type_to_transform: PassType, output_dst: DerivedScalarType, ) -> "ScalarDeriver": """Converts one ScalarDeriver to another, by applying a tensor -> tensor function to the output. The tensor -> tensor function takes a tensor, a layer index, and a pass type, and returns a tensor, so that it can depend on layer and pass type.""" sub_scalar_sources = (DerivedScalarSource(self, pass_type=pass_type_to_transform),) def tensor_calculate_derived_scalar_fn( activation_data_tuple: tuple[torch.Tensor, ...], layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert len(activation_data_tuple) == 1 return layerwise_transform_fn(activation_data_tuple[0], layer_index, pass_type) return dataclasses.replace( self, dst=output_dst, sub_scalar_sources=sub_scalar_sources, tensor_calculate_derived_scalar_fn=tensor_calculate_derived_scalar_fn, _specified_location_within_layer=self.location_within_layer, ) def apply_layerwise_transform_fn_to_output_and_other_tensor( self, layerwise_transform_fn: Callable[..., torch.Tensor], pass_type_to_transform: PassType, output_dst: DerivedScalarType, other_scalar_source: ScalarSource, ) -> "ScalarDeriver": """Converts one ScalarDeriver to another, by applying a two tensor -> tensor function to the output + an additional activation tensor. The tensor -> tensor function takes two tensors, a layer index, and a pass type, and returns a tensor, so that it can depend on layer and pass type.""" sub_scalar_sources = ( DerivedScalarSource( self, pass_type=pass_type_to_transform, layer_indexer=DEFAULT_LAYER_INDEXER ), other_scalar_source, ) def tensor_calculate_derived_scalar_fn( activation_data_tuple: tuple[torch.Tensor, ...], layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert len(activation_data_tuple) == 2, [t.shape for t in activation_data_tuple] return layerwise_transform_fn(*activation_data_tuple, layer_index, pass_type) return dataclasses.replace( self, dst=output_dst, sub_scalar_sources=sub_scalar_sources, tensor_calculate_derived_scalar_fn=tensor_calculate_derived_scalar_fn, _specified_location_within_layer=self.location_within_layer, )