neuron_explainer/activations/derived_scalars/edge_activation.py (63 lines of code) (raw):

"""This file defines ScalarDerivers for efficiently computing the direct effect of a single upstream node on many downstream nodes.""" from typing import Callable from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.node_write import make_node_write_scalar_source from neuron_explainer.activations.derived_scalars.reconstituter_class import ActivationReconstituter from neuron_explainer.activations.derived_scalars.scalar_deriver import ( DstConfig, ScalarDeriver, ScalarSource, ) from neuron_explainer.models.model_component_registry import ActivationLocationType from neuron_explainer.models.model_context import StandardModelContext def convert_node_write_scalar_deriver_to_in_edge_activation( node_write_scalar_source: ScalarSource, output_dst: DerivedScalarType, dst_config: DstConfig, downstream_activation_location_type: ActivationLocationType, downstream_q_or_k: ActivationLocationType | None, ) -> ScalarDeriver: """Converts a scalar deriver for a write vector from some upstream node type to a scalar deriver for in edge activation for downstream nodes of some type (MLP, autoencoder, or attention head). In the case of attention heads, this is split up by subnode (Q or K).""" model_context = dst_config.get_model_context() autoencoder_context = dst_config.get_autoencoder_context() assert isinstance(model_context, StandardModelContext) transformer = model_context.get_or_create_model() reconstituter = ActivationReconstituter.from_activation_location_type( transformer=transformer, autoencoder_context=autoencoder_context, activation_location_type=downstream_activation_location_type, q_or_k=downstream_q_or_k, ) return reconstituter.make_jvp_scalar_deriver( write_scalar_source=node_write_scalar_source, dst_config=dst_config, output_dst=output_dst, ) def make_in_edge_activation_scalar_deriver_factory( activation_location_type: ActivationLocationType, q_or_k: ActivationLocationType | None = None, ) -> Callable[[DstConfig], ScalarDeriver]: """Returns a function that creates a scalar deriver for the edge attribution from arbitrary node to the specified downstream activation location type / sub activation location type (MLP post act, autoencoder latent, attention head Q or K). """ sub_node_type_to_output_dst = { (ActivationLocationType.MLP_POST_ACT, None): DerivedScalarType.MLP_IN_EDGE_ACTIVATION, ( ActivationLocationType.ONLINE_AUTOENCODER_LATENT, None, ): DerivedScalarType.ONLINE_AUTOENCODER_IN_EDGE_ACTIVATION, ( ActivationLocationType.ATTN_QK_PROBS, ActivationLocationType.ATTN_QUERY, ): DerivedScalarType.ATTN_QUERY_IN_EDGE_ACTIVATION, ( ActivationLocationType.ATTN_QK_PROBS, ActivationLocationType.ATTN_KEY, ): DerivedScalarType.ATTN_KEY_IN_EDGE_ACTIVATION, } output_dst = sub_node_type_to_output_dst[(activation_location_type, q_or_k)] def make_in_edge_activation_scalar_deriver(dst_config: DstConfig) -> ScalarDeriver: node_write_scalar_source = make_node_write_scalar_source(dst_config) return convert_node_write_scalar_deriver_to_in_edge_activation( node_write_scalar_source=node_write_scalar_source, output_dst=output_dst, dst_config=dst_config, downstream_activation_location_type=activation_location_type, downstream_q_or_k=q_or_k, ) return make_in_edge_activation_scalar_deriver