neuron_explainer/activations/derived_scalars/direct_effects.py (416 lines of code) (raw):

""" "Direct effects" of one node on another, or on the loss, are defined by first computing the gradient of the downstream node's activation with respect to the residual stream immediately preceding the downstream node. We then compute the inner product of this gradient with the write vector of the upstream node. If the upstream node is in the residual stream basis, then it is considered to be its own "write vector" for this purpose. This file contains code for performing the computation described above, for upstream nodes of various types. """ import dataclasses from typing import Callable import torch from neuron_explainer.activations.derived_scalars.config import TraceConfig from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.indexing import AttnSubNodeIndex, NodeIndex from neuron_explainer.activations.derived_scalars.locations import ( ConstantLayerIndexer, get_previous_residual_dst_for_node_type, precedes_final_layer, ) from neuron_explainer.activations.derived_scalars.raw_activations import ( check_write_tensor_device_matches, ) from neuron_explainer.activations.derived_scalars.reconstituted import make_apply_attn_V_act from neuron_explainer.activations.derived_scalars.reconstituter_class import ( Reconstituter, make_no_backward_pass_scalar_source_for_final_residual_grad, ) from neuron_explainer.activations.derived_scalars.scalar_deriver import ( DerivedScalarSource, DstConfig, RawScalarSource, ScalarDeriver, ScalarSource, ) from neuron_explainer.activations.derived_scalars.write_tensors import ( get_attn_write_tensor_by_layer_index, get_autoencoder_write_tensor_by_layer_index, get_mlp_write_tensor_by_layer_index, ) from neuron_explainer.models.model_component_registry import ( ActivationLocationType, LayerIndex, NodeType, PassType, ) from neuron_explainer.models.model_context import ModelContext def make_write_to_direction_tensor_fn( node_type: NodeType, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor] | None, layer_precedes_direction_layer_fn: Callable[[LayerIndex], bool], ) -> Callable[[torch.Tensor, torch.Tensor, LayerIndex, PassType], torch.Tensor]: """ To convert an "activation" tensor to a "projection of write to direction" tensor, we need to convert the "activation" to the residual stream basis (using a write tensor) if it is not already, and then project to the direction of interest. This function constructs the appropriate tensor operation to perform this projection based on the node_type of the activation tensor (passed as an argument). The write_tensor_by_layer_index argument defines the conversion to the residual stream basis, and the layer_precedes_direction_layer_fn argument is assumed to return True iff the derived scalar at a given layer index is upstream of the direction of interest. """ match node_type: case NodeType.RESIDUAL_STREAM_CHANNEL: assert write_tensor_by_layer_index is None def inner_product_with_residual( residual: torch.Tensor, direction: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: # (num_sequence_tokens, 1) assert pass_type == PassType.FORWARD assert residual.ndim == 2 assert residual.shape == direction.shape if layer_precedes_direction_layer_fn(layer_index): return torch.einsum("td,td->t", residual, direction)[ :, None ] # sum over residual stream channels else: return torch.zeros_like(residual[:, 0:1]) return inner_product_with_residual case NodeType.MLP_NEURON | NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT: assert write_tensor_by_layer_index is not None def multiply_by_projection_to_direction( activations: torch.Tensor, direction: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> ( torch.Tensor ): # (num_sequence_tokens, num_activations [i.e. num_neurons, num_latents]) assert layer_index is not None assert layer_index in write_tensor_by_layer_index assert pass_type == PassType.FORWARD assert activations.ndim == direction.ndim == 2 if layer_precedes_direction_layer_fn(layer_index): write_projection = torch.einsum( "ao,to->ta", write_tensor_by_layer_index[layer_index], direction ) return activations * write_projection else: return torch.zeros_like(activations) return multiply_by_projection_to_direction case NodeType.V_CHANNEL: assert write_tensor_by_layer_index is not None def attn_write_to_residual_direction_tensor_calculate_derived_scalar_fn( attn_weighted_values: torch.Tensor, direction: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> ( torch.Tensor ): # (num_sequence_tokens, num_heads) or (num_sequence_tokens, num_attended_to_sequence_tokens, num_heads) assert layer_index is not None assert layer_index in write_tensor_by_layer_index assert pass_type == PassType.FORWARD # one or two token dimensions, one head dimension, one value channel dimension assert attn_weighted_values.ndim in {3, 4} if layer_precedes_direction_layer_fn(layer_index): return compute_attn_write_to_residual_direction_from_attn_weighted_values( attn_weighted_values=attn_weighted_values, residual_direction=direction, W_O=write_tensor_by_layer_index[layer_index], pass_type=pass_type, ) # TODO: consider splitting into two cases, once we have separate node_types else: return torch.zeros_like(attn_weighted_values[..., 0]) # sum over last dimension return attn_write_to_residual_direction_tensor_calculate_derived_scalar_fn case _: raise NotImplementedError( f"make_write_to_direction_tensor_fn not implemented for {node_type=}" ) def compute_attn_write_to_residual_direction_from_attn_weighted_values( attn_weighted_values: torch.Tensor, residual_direction: torch.Tensor, W_O: torch.Tensor, # hdo pass_type: PassType, ) -> torch.Tensor: assert ( pass_type == PassType.FORWARD ), "only forward pass implemented for now for attn write norm from weighted sum of values" if attn_weighted_values.ndim == 3: num_sequence_tokens, nheads, d_head = attn_weighted_values.shape else: assert attn_weighted_values.ndim == 4 ( num_sequence_tokens, num_attended_to_sequence_tokens, nheads, d_head, ) = attn_weighted_values.shape assert residual_direction.shape[0] == num_sequence_tokens _, d_model = residual_direction.shape assert W_O.shape == (nheads, d_head, d_model) W_O = W_O.to(residual_direction.dtype) Wo_projection = torch.einsum("hdo,to->thd", W_O, residual_direction) if attn_weighted_values.ndim == 3: v_times_Wo_projection = torch.einsum( "thd,thd->th", attn_weighted_values, Wo_projection ) # optionally either one or two token dimensions else: assert attn_weighted_values.ndim == 4 v_times_Wo_projection = torch.einsum( "tuhd,thd->tuh", attn_weighted_values, Wo_projection ) # optionally either one or two token dimensions assert (v_times_Wo_projection.shape[0], v_times_Wo_projection.shape[-1]) == ( num_sequence_tokens, nheads, ) return v_times_Wo_projection def convert_scalar_deriver_to_write_to_direction_with_write_tensor( scalar_deriver: ScalarDeriver, write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | dict[int, torch.Tensor] | None, direction_scalar_source: ScalarSource, output_dst: DerivedScalarType, ) -> ScalarDeriver: """Takes as input a scalar deriver for a scalar activation fully defining a write direction (e.g. MLP activation or autoencoder but not post-softmax attention) and a scalar deriver for a direction in the residual stream basis. Multiplies each activation by its associated write vector and projects to the direction of interest.""" if write_tensor_by_layer_index is not None: check_write_tensor_device_matches( scalar_deriver, write_tensor_by_layer_index, ) def derived_scalar_precedes_direction_layer(layer_index: LayerIndex) -> bool: return precedes_final_layer( final_residual_location_within_layer=direction_scalar_source.location_within_layer, final_residual_layer_index=direction_scalar_source.layer_index, derived_scalar_location_within_layer=scalar_deriver.location_within_layer, derived_scalar_layer_index=layer_index, ) write_to_direction_tensor_fn = make_write_to_direction_tensor_fn( node_type=scalar_deriver.dst.node_type, write_tensor_by_layer_index=write_tensor_by_layer_index, layer_precedes_direction_layer_fn=derived_scalar_precedes_direction_layer, ) return scalar_deriver.apply_layerwise_transform_fn_to_output_and_other_tensor( write_to_direction_tensor_fn, pass_type_to_transform=PassType.FORWARD, output_dst=output_dst, other_scalar_source=direction_scalar_source, ) def make_final_residual_grad_scalar_source( dst_config: DstConfig, use_backward_pass: bool, ) -> ScalarSource: """ Many DSTs depend on the residual stream gradient at the last point in the forward pass before the point from which the backward pass is run. There are two ways of deriving this residual stream gradient. Background on backward passes: By default, the backward pass is run starting from some scalar function of the transformer's output logits. In this case, the last relevant point in the forward pass is at the very last residual stream location in the network (pre- final layer norm). A backward pass can also be run from an arbitrary activation in the network. In this case, the last relevant point in the forward pass is at the residual stream location immediately preceding the layer index of the activation from which the backward pass is run (pre- layer norm for that layer). The DstConfig object specifies whether the backward pass is the default (trace_config=None) or from an activation (trace_config=TraceConfig(node_index=NodeIndex(),...)). Note that if all you care about for a particular DST is the gradient at the last point in the forward pass (i.e. the first point in the backward pass), then running the full backward pass is actually wasteful. If you need to compute gradients with respect to many different activations, it's best just to run the very first part of the backward pass if possible. This is what use_backward_pass=False does. Two ways of deriving the residual stream gradient: - use_backward_pass=True: assume a literal backward pass has been run, outside the DST setup, as specified by the DstConfig object. In this case, you can directly use the "raw" residual stream gradient at a location inferrable from dst_config.trace_config. - use_backward_pass=False: (specific to the case where trace_config is not None) do not make assumptions about the literal backward pass that has been run. Take the residual stream **activations** (the forward pass) at the location implied by dst_config.trace_config. Recompute the activation specified from those residual stream activations, and run a small backward pass on the activation, back to those residual stream activations. """ if use_backward_pass: return make_backward_pass_scalar_source_for_final_residual_grad(dst_config) else: return make_no_backward_pass_scalar_source_for_final_residual_grad(dst_config) def make_backward_pass_scalar_source_for_final_residual_grad( dst_config: DstConfig, ) -> ScalarSource: """Called by other make_scalar_deriver functions; not needed as a derived scalar on its own. Note that the dst_config is not used for the (temporary) ScalarDeriver that is returned by this function. This determines the config needed for a final_residual_grad scalar deriver, based on the config of the scalar deriver for the activation which will be multiplied by the final residual grad. """ if ( dst_config.trace_config is not None and dst_config.trace_config.node_type.is_autoencoder_latent ): autoencoder_dst = dst_config.get_autoencoder_dst(dst_config.trace_config.node_type) else: autoencoder_dst = None return make_backward_pass_scalar_source_for_final_residual_grad_helper( n_layers=dst_config.get_n_layers(), trace_config=dst_config.trace_config, autoencoder_dst=autoencoder_dst, ) def make_backward_pass_scalar_source_for_fake_final_residual_grad( dst_config: DstConfig, ) -> ScalarSource: """Called by other make_scalar_deriver functions; not needed as a derived scalar on its own. Note that the dst_config is not used for the (temporary) ScalarDeriver that is returned by this function. This determines the config needed for a final_fake_residual_grad scalar deriver, based on the config of the scalar deriver for the activation which will be multiplied by the final fake residual grad. The gradient is "fake" in the sense that a real backward pass is run from a later point in the network, but the gradient is assumed to be ablated such that a real gradient of interest can be computed at the residual stream immediately preceding the layer_index of dst_config.activation_index_for_fake_grad. """ assert dst_config.activation_index_for_fake_grad is not None if ( dst_config.trace_config is not None and dst_config.trace_config.node_type.is_autoencoder_latent ): autoencoder_dst = dst_config.get_autoencoder_dst(dst_config.trace_config.node_type) else: autoencoder_dst = None return make_backward_pass_scalar_source_for_final_residual_grad_helper( n_layers=dst_config.get_n_layers(), trace_config=TraceConfig.from_activation_index( activation_index=dst_config.activation_index_for_fake_grad ), autoencoder_dst=autoencoder_dst, ) def make_backward_pass_scalar_source_for_final_residual_grad_helper( n_layers: int, # total layers in model trace_config: TraceConfig | None, autoencoder_dst: DerivedScalarType | None, ) -> ScalarSource: """ Returns the location of the last residual stream location prior to the layer norm preceding the location from which .backward() is being computed """ # lazily avoid circular import from neuron_explainer.activations.derived_scalars.make_scalar_derivers import ( make_scalar_deriver, ) if trace_config is None: return RawScalarSource( activation_location_type=ActivationLocationType.RESID_POST_MLP, pass_type=PassType.BACKWARD, layer_indexer=ConstantLayerIndexer(n_layers - 1), ) else: layer_index = trace_config.layer_index assert layer_index is not None residual_dst = get_previous_residual_dst_for_node_type( node_type=trace_config.node_type, autoencoder_dst=autoencoder_dst, ) return DerivedScalarSource( scalar_deriver=make_scalar_deriver( residual_dst, DstConfig(layer_indices=[layer_index], derive_gradients=True) ), pass_type=PassType.BACKWARD, layer_indexer=ConstantLayerIndexer(layer_index), ) def convert_scalar_deriver_to_write_to_final_residual_grad( scalar_deriver: ScalarDeriver, output_dst: DerivedScalarType, use_existing_backward_pass_for_final_residual_grad: bool, ) -> ScalarDeriver: direction_scalar_source = make_final_residual_grad_scalar_source( scalar_deriver.dst_config, use_existing_backward_pass_for_final_residual_grad ) return convert_scalar_deriver_to_write_to_direction( scalar_deriver=scalar_deriver, direction_scalar_source=direction_scalar_source, output_dst=output_dst, ) def convert_scalar_deriver_to_write_to_direction( scalar_deriver: ScalarDeriver, direction_scalar_source: ScalarSource, output_dst: DerivedScalarType, ) -> ScalarDeriver: model_context = scalar_deriver.dst_config.get_model_context() layer_indices = scalar_deriver.dst_config.layer_indices or list(range(model_context.n_layers)) node_type = scalar_deriver.dst.node_type match node_type: case NodeType.RESIDUAL_STREAM_CHANNEL: write_tensor_by_layer_index: dict[LayerIndex, torch.Tensor] | None = None case NodeType.MLP_NEURON: write_tensor_by_layer_index = get_mlp_write_tensor_by_layer_index( model_context=model_context, layer_indices=layer_indices, ) case NodeType.V_CHANNEL: write_tensor_by_layer_index = get_attn_write_tensor_by_layer_index( model_context=model_context, layer_indices=layer_indices, ) case ( NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT ): autoencoder_context = scalar_deriver.dst_config.get_autoencoder_context(node_type) assert autoencoder_context is not None write_tensor_by_layer_index = get_autoencoder_write_tensor_by_layer_index( model_context=model_context, autoencoder_context=autoencoder_context, ) case _: raise NotImplementedError( f"convert_scalar_deriver_to_write_to_direction not implemented for {node_type=}" ) return convert_scalar_deriver_to_write_to_direction_with_write_tensor( scalar_deriver=scalar_deriver, write_tensor_by_layer_index=write_tensor_by_layer_index, direction_scalar_source=direction_scalar_source, output_dst=output_dst, ) def make_reconstituted_attention_direct_effect_fn( model_context: ModelContext, layer_indices: list[int] | None, detach_layer_norm_scale: bool, ) -> Callable[[torch.Tensor, torch.Tensor, LayerIndex, PassType], torch.Tensor]: apply_attn_V_act = make_apply_attn_V_act( transformer=model_context.get_or_create_model(), q_k_or_v=ActivationLocationType.ATTN_VALUE, detach_layer_norm_scale=detach_layer_norm_scale, ) write_tensor_by_layer_index = get_attn_write_tensor_by_layer_index( model_context=model_context, layer_indices=layer_indices, ) def direct_effect_fn( resid: torch.Tensor, grad: torch.Tensor, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: # grad is a d_model-dimensional vector attn, V = apply_attn_V_act(resid, layer_index, pass_type) attn_weighted_V = torch.einsum("qkh,khd->qkhd", attn, V) grad_proj_to_V = torch.einsum("hdo,qo->qhd", write_tensor_by_layer_index[layer_index], grad) # grad is w/r/t (attn_weighted_V summed over k, or ATTN_WEIGHTED_SUM_OF_VALUES) return torch.einsum("qkhd,qhd->qkh", attn_weighted_V, grad_proj_to_V) return direct_effect_fn class AttentionDirectEffectReconstituter(Reconstituter): """Reconstitute an attention head's write to a particular direction""" requires_other_scalar_source = True node_type = NodeType.ATTENTION_HEAD def __init__( self, model_context: ModelContext, layer_indices: list[int] | None, detach_layer_norm_scale: bool, ): super().__init__() self._reconstitute_activations_fn = make_reconstituted_attention_direct_effect_fn( model_context=model_context, layer_indices=layer_indices, detach_layer_norm_scale=detach_layer_norm_scale, ) self._layer_indices = layer_indices self.residual_dst = get_previous_residual_dst_for_node_type( node_type=self.node_type, autoencoder_dst=None, ) def reconstitute_activations( self, resid: torch.Tensor, grad: torch.Tensor | None, layer_index: LayerIndex, pass_type: PassType, ) -> torch.Tensor: assert pass_type == PassType.FORWARD assert grad is not None return self._reconstitute_activations_fn( resid, grad, layer_index, pass_type, ) def make_other_scalar_source(self, dst_config: DstConfig) -> ScalarSource: # make_backward_pass_scalar_source_for_final_residual_grad # does not use most of the fields of dst_config; just # get_n_layers(), get_autoencoder_dst(), and trace_config return make_backward_pass_scalar_source_for_final_residual_grad(dst_config) def _check_node_index(self, node_index: NodeIndex) -> None: assert node_index.node_type == self.node_type assert node_index.pass_type == PassType.FORWARD assert node_index.layer_index is not None # self._layer_indices = None -> support all layer_indices; otherwise only a subset # of layer indices are loaded assert self._layer_indices is None or node_index.layer_index in self._layer_indices if isinstance(node_index, AttnSubNodeIndex): assert node_index.q_k_or_v == ActivationLocationType.ATTN_VALUE def make_scalar_hook_for_node_index( self, node_index: NodeIndex ) -> Callable[[torch.Tensor], torch.Tensor]: self._check_node_index(node_index) assert node_index.ndim == 0 def get_activation_from_layer_activations(layer_activations: torch.Tensor) -> torch.Tensor: return layer_activations[node_index.tensor_indices] return get_activation_from_layer_activations def make_gradient_scalar_deriver_for_node_index( self, node_index: NodeIndex, dst_config: DstConfig, output_dst: DerivedScalarType | None = None, ) -> ScalarDeriver: self._check_node_index(node_index) assert node_index.layer_index is not None dst_config_for_layer = dataclasses.replace( dst_config, layer_indices=[node_index.layer_index], ) scalar_hook = self.make_scalar_hook_for_node_index(node_index) return self.make_gradient_scalar_deriver( scalar_hook=scalar_hook, dst_config=dst_config_for_layer, output_dst=output_dst, ) def make_gradient_scalar_source_for_node_index( self, node_index: NodeIndex, dst_config: DstConfig, output_dst: DerivedScalarType | None = None, ) -> DerivedScalarSource: scalar_hook = self.make_scalar_hook_for_node_index(node_index) gradient_scalar_deriver = self.make_gradient_scalar_deriver( scalar_hook=scalar_hook, dst_config=dst_config, output_dst=output_dst, ) assert node_index.layer_index is not None return DerivedScalarSource( scalar_deriver=gradient_scalar_deriver, pass_type=PassType.FORWARD, layer_indexer=ConstantLayerIndexer(node_index.layer_index), )