neuron_explainer/activations/derived_scalars/raw_activations.py

""" This file contains code to make scalar derivers for scalar types that are 1:1 with an ActivationLocationType. """ from functools import partial from typing import Callable import torch from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.locations import ( IdentityLayerIndexer, LayerIndexer, NoLayersLayerIndexer, ) from neuron_explainer.activations.derived_scalars.scalar_deriver import ( DstConfig, PassType, RawScalarSource, ScalarDeriver, ) from neuron_explainer.models.model_component_registry import ( ActivationLocationType, LayerIndex, NodeType, ) def get_scalar_sources_for_activation_location_types( activation_location_type: ActivationLocationType, derive_gradients: bool, ) -> tuple[RawScalarSource, ...]: if activation_location_type.has_no_layers: layer_indexer: LayerIndexer = NoLayersLayerIndexer() else: layer_indexer = IdentityLayerIndexer() if derive_gradients: return ( RawScalarSource( activation_location_type=activation_location_type, pass_type=PassType.FORWARD, layer_indexer=layer_indexer, ), RawScalarSource( activation_location_type=activation_location_type, pass_type=PassType.BACKWARD, layer_indexer=layer_indexer, ), ) else: return ( RawScalarSource( activation_location_type=activation_location_type, pass_type=PassType.FORWARD, layer_indexer=layer_indexer, ), ) def no_op_tensor_calculate_derived_scalar_fn( raw_activation_data_tuple: tuple[torch.Tensor, ...], layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: """ This either: converts a length 1 tuple of tensors into a single tensor; pass_type is asserted to be PassType.FORWARD or converts a length 2 tuple of tensors, one for the forward pass and one for the backward pass, into the appropriate one of those two objects, depending on the pass_type argument. """ if len(raw_activation_data_tuple) == 1: # in this case, only the activations at the relevant ActivationLocationType have been loaded from disk assert pass_type == PassType.FORWARD raw_activation_data = raw_activation_data_tuple[0] return raw_activation_data elif len(raw_activation_data_tuple) == 2: # in this case, both the activations and gradients at the relevant ActivationLocationType have been loaded from disk raw_activation_data, raw_gradient_data = raw_activation_data_tuple if pass_type == PassType.FORWARD: return raw_activation_data elif pass_type == PassType.BACKWARD: return raw_gradient_data else: raise ValueError(f"Unknown {pass_type=}") else: raise ValueError(f"Unknown {raw_activation_data_tuple=}") def make_scalar_deriver_factory_for_activation_location_type( activation_location_type: ActivationLocationType, ) -> Callable[[DstConfig], ScalarDeriver]: """ This is for DerivedScalarType's 1:1 with a ActivationLocationType, which can be generated from just the ActivationLocationType and no additional information. 
""" def make_scalar_deriver_fn( dst_config: DstConfig, ) -> ScalarDeriver: sub_scalar_sources = get_scalar_sources_for_activation_location_types( activation_location_type, dst_config.derive_gradients ) return ScalarDeriver( dst=DerivedScalarType.from_activation_location_type(activation_location_type), dst_config=dst_config, sub_scalar_sources=sub_scalar_sources, tensor_calculate_derived_scalar_fn=no_op_tensor_calculate_derived_scalar_fn, ) return make_scalar_deriver_fn def make_scalar_deriver_factory_for_act_times_grad( activation_location_type: ActivationLocationType, dst: DerivedScalarType, ) -> Callable[[DstConfig], ScalarDeriver]: def make_act_times_grad_scalar_deriver( dst_config: DstConfig, ) -> ScalarDeriver: assert not dst_config.derive_gradients, "Gradients not defined for act times grad" if activation_location_type.has_no_layers: layer_indexer: LayerIndexer = NoLayersLayerIndexer() else: layer_indexer = IdentityLayerIndexer() sub_scalar_sources = ( RawScalarSource( activation_location_type=activation_location_type, pass_type=PassType.FORWARD, layer_indexer=layer_indexer, ), # activations RawScalarSource( activation_location_type=activation_location_type, pass_type=PassType.BACKWARD, layer_indexer=layer_indexer, ), # gradients ) def _act_times_grad_tensor_calculate_derived_scalar_fn( raw_activation_data_tuple: tuple[torch.Tensor, ...], layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD, "Backward pass not defined for act times grad" assert len(raw_activation_data_tuple) == 2 raw_activation_data, raw_gradient_data = raw_activation_data_tuple return raw_activation_data * raw_gradient_data return ScalarDeriver( dst=dst, dst_config=dst_config, sub_scalar_sources=sub_scalar_sources, tensor_calculate_derived_scalar_fn=_act_times_grad_tensor_calculate_derived_scalar_fn, ) return make_act_times_grad_scalar_deriver def check_write_tensor_device_matches( scalar_deriver: ScalarDeriver, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor], ) -> None: write_matrix_device = next(iter(write_tensor_by_layer_index.values())).device assert scalar_deriver.device_for_raw_activations == write_matrix_device, ( scalar_deriver.dst, scalar_deriver.device_for_raw_activations, write_matrix_device, ) def convert_scalar_deriver_to_write_norm( scalar_deriver: ScalarDeriver, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor], output_dst: DerivedScalarType, ) -> ScalarDeriver: """ Converts a scalar deriver for a scalar type that is 1:1 with an ActivationLocationType to a scalar deriver for the write norm for each neuron at each token. 
""" check_write_tensor_device_matches( scalar_deriver, write_tensor_by_layer_index, ) write_norm_by_layer_index = { layer_index_: write_tensor_by_layer_index[layer_index_].norm(dim=-1) # type: ignore for layer_index_ in write_tensor_by_layer_index.keys() } def multiply_by_write_norm( activations: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD, "Backward pass not defined for write norm" assert ( layer_index in write_tensor_by_layer_index ), f"{layer_index=} not in {write_tensor_by_layer_index.keys()=} for {output_dst=}" return activations * write_norm_by_layer_index[layer_index] return scalar_deriver.apply_layerwise_transform_fn_to_output( multiply_by_write_norm, pass_type_to_transform=PassType.FORWARD, output_dst=output_dst, ) def convert_scalar_deriver_to_write( scalar_deriver: ScalarDeriver, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor], output_dst: DerivedScalarType, ) -> ScalarDeriver: """Converts a scalar deriver for a scalar type that is 1:1 with an ActivationLocationType to a scalar deriver for the write vector of the layer at each token.""" check_write_tensor_device_matches( scalar_deriver, write_tensor_by_layer_index, ) def multiply_by_write( activations: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD, "Backward pass not defined for write" assert ( layer_index in write_tensor_by_layer_index ), f"{layer_index=} not in {write_tensor_by_layer_index.keys()=}" return torch.einsum( "ta,ao->to", activations, write_tensor_by_layer_index[layer_index], # type: ignore ) return scalar_deriver.apply_layerwise_transform_fn_to_output( multiply_by_write, pass_type_to_transform=PassType.FORWARD, output_dst=output_dst, ) def convert_scalar_deriver_to_write_vector( scalar_deriver: ScalarDeriver, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor], output_dst: DerivedScalarType, ) -> ScalarDeriver: """ Converts a scalar deriver for a scalar type that is 1:1 with an ActivationLocationType to a scalar deriver for the write vector of the layer at each token. Must be a scalar type that is related to the residual stream basis by a straightforward matmul (e.g. MLP post-activations are related to the residual stream basis by WeightLocationType.MLP_TO_RESIDUAL). """ check_write_tensor_device_matches( scalar_deriver, write_tensor_by_layer_index, ) assert scalar_deriver.dst.node_type in { NodeType.MLP_NEURON, NodeType.V_CHANNEL, NodeType.AUTOENCODER_LATENT, NodeType.MLP_AUTOENCODER_LATENT, NodeType.ATTENTION_AUTOENCODER_LATENT, } def multiply_by_write_vector( activations: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD, "Backward pass not defined for write" assert ( layer_index in write_tensor_by_layer_index ), f"{layer_index=} not in {write_tensor_by_layer_index.keys()=}" return torch.einsum( "ta,ao->tao", activations, write_tensor_by_layer_index[layer_index], # type: ignore ) return scalar_deriver.apply_layerwise_transform_fn_to_output( multiply_by_write_vector, pass_type_to_transform=PassType.FORWARD, output_dst=output_dst, ) def truncate_to_expected_shape( tensor: torch.Tensor, expected_shape: list[int | None], ) -> torch.Tensor: """ This asserts that the tensor has the expected shape, and optionally truncates it to that shape. None in expected_shape means that dimension is not checked. 
""" if expected_shape is None: return tensor for dim, real_size, expected_size in zip( range(len(expected_shape)), tensor.shape, expected_shape ): if expected_size is not None: assert ( real_size >= expected_size ), f"Dimension {dim} of tensor has size {real_size} but expected size {expected_size}" tensor = tensor.narrow(dim, 0, expected_size) return tensor def truncate_to_expected_shape_tensor_calculate_derived_scalar_fn( raw_activation_data_tuple: tuple[torch.Tensor, ...], layer_index: LayerIndex, pass_type: PassType, expected_shape: list[int | None], ) -> torch.Tensor: """This either: converts a length 1 tuple of tensors into a single tensor; pass_type is asserted to be PassType.FORWARD or converts a length 2 tuple of tensors, one for the forward pass and one for the backward pass, into the appropriate one of those two objects, depending on the pass_type argument.""" if len(raw_activation_data_tuple) == 1: # in this case, only the activations at the relevant ActivationLocationType have been loaded from disk assert pass_type == PassType.FORWARD raw_activation_data = raw_activation_data_tuple[0] raw_activation_data = truncate_to_expected_shape( raw_activation_data, expected_shape, ) return raw_activation_data elif len(raw_activation_data_tuple) == 2: # in this case, both the activations and gradients at the relevant ActivationLocationType have been loaded from disk raw_activation_data, raw_gradient_data = raw_activation_data_tuple if pass_type == PassType.FORWARD: raw_activation_data = truncate_to_expected_shape( raw_activation_data, expected_shape, ) return raw_activation_data elif pass_type == PassType.BACKWARD: raw_gradient_data = truncate_to_expected_shape( raw_gradient_data, expected_shape, ) return raw_gradient_data else: raise ValueError(f"Unknown {pass_type=}") else: raise ValueError(f"Unknown {raw_activation_data_tuple=}") # TODO: this entire function should be simplified or deleted? # Can possibly just use make_scalar_deriver_factory_for_activation_location_type(ActivationLocationType.LOGITS) # in the one place it is called def make_truncate_to_expected_shape_scalar_deriver_factory_for_dst( dst: DerivedScalarType, ) -> Callable[[DstConfig], ScalarDeriver]: """ This is for DerivedScalarType's 1:1 with a ActivationLocationType, which can be generated from just the ActivationLocationType and no additional information. """ untruncated_activation_location_type_by_truncated_dst = { DerivedScalarType.LOGITS: ActivationLocationType.LOGITS, } assert ( dst in untruncated_activation_location_type_by_truncated_dst ), f"No untruncated ActivationLocationType for this DerivedScalarType: {dst}" activation_location_type = untruncated_activation_location_type_by_truncated_dst[dst] def make_scalar_deriver_fn( dst_config: DstConfig, ) -> ScalarDeriver: sub_scalar_sources = get_scalar_sources_for_activation_location_types( activation_location_type, dst_config.derive_gradients ) model_context = dst_config.get_model_context() expected_dimensions = dst.shape_spec_per_token_sequence expected_shape: list[int | None] = [] for dimension in expected_dimensions: if dimension.is_model_intrinsic: expected_shape.append(model_context.get_dim_size(dimension)) else: expected_shape.append(None) return ScalarDeriver( dst=dst, dst_config=dst_config, sub_scalar_sources=sub_scalar_sources, tensor_calculate_derived_scalar_fn=partial( truncate_to_expected_shape_tensor_calculate_derived_scalar_fn, expected_shape=expected_shape, ), ) return make_scalar_deriver_fn