neuron_explainer/activations/derived_scalars/make_scalar_derivers.py (340 lines of code) (raw):

"""
This file contains functions for generating a ScalarDeriver based on a DerivedScalarType and a
DstConfig (`make_scalar_deriver`) or, for convenience, based on just an ActivationLocationType
(`make_scalar_deriver_for_activation_location_type`). It calls make_scalar_deriver_... functions
defined in other files within derived_scalars/.
"""

from typing import Callable

from neuron_explainer.activations.derived_scalars.attention import (
    make_attn_act_times_grad_per_sequence_token_scalar_deriver,
    make_attn_weighted_value_scalar_deriver,
    make_attn_write_norm_per_sequence_token_scalar_deriver,
    make_attn_write_norm_scalar_deriver,
    make_attn_write_scalar_deriver,
    make_attn_write_sum_heads_scalar_deriver,
    make_attn_write_to_final_residual_grad_per_sequence_token_scalar_deriver,
    make_attn_write_to_latent_per_sequence_token_batched_scalar_deriver,
    make_attn_write_to_latent_per_sequence_token_scalar_deriver,
    make_attn_write_to_latent_scalar_deriver,
    make_attn_write_to_latent_summed_over_heads_scalar_deriver,
    make_flattened_attn_post_softmax_act_times_grad_scalar_deriver,
    make_flattened_attn_post_softmax_scalar_deriver,
    make_flattened_attn_write_to_final_residual_grad_scalar_deriver,
    make_flattened_attn_write_to_latent_summed_over_heads_batched_scalar_deriver,
    make_flattened_attn_write_to_latent_summed_over_heads_scalar_deriver,
    make_unflattened_attn_write_norm_scalar_deriver,
    make_unflattened_attn_write_to_final_activation_residual_grad_scalar_deriver,
    make_unflattened_attn_write_to_final_residual_grad_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.autoencoder import (
    make_autoencoder_latent_grad_wrt_mlp_post_act_input_scalar_deriver,
    make_autoencoder_latent_grad_wrt_residual_input_scalar_deriver,
    make_autoencoder_latent_scalar_deriver_factory,
    make_autoencoder_write_norm_scalar_deriver_factory,
    make_online_autoencoder_act_times_grad_scalar_deriver_factory,
    make_online_autoencoder_error_scalar_deriver_factory,
    make_online_autoencoder_latent_scalar_deriver_factory,
    make_online_autoencoder_latentwise_write_scalar_deriver_factory,
    make_online_autoencoder_write_norm_scalar_deriver_factory,
    make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver_factory,
    make_online_autoencoder_write_to_final_residual_grad_scalar_deriver_factory,
    make_online_mlp_autoencoder_error_act_times_grad_scalar_deriver,
    make_online_mlp_autoencoder_error_write_norm_scalar_deriver,
    make_online_mlp_autoencoder_error_write_to_final_residual_grad_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.config import DstConfig
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.edge_activation import (
    make_in_edge_activation_scalar_deriver_factory,
)
from neuron_explainer.activations.derived_scalars.edge_attribution import (
    make_attn_out_edge_attribution_scalar_deriver,
    make_grad_of_downstream_subnode_attribution_scalar_deriver,
    make_in_edge_attribution_scalar_deriver_factory,
    make_mlp_out_edge_attribution_scalar_deriver,
    make_node_write_scalar_deriver,
    make_node_write_to_final_residual_grad_scalar_deriver,
    make_online_autoencoder_out_edge_attribution_scalar_deriver,
    make_token_out_edge_attribution_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.mlp import (
    make_mlp_neuronwise_write_scalar_deriver,
    make_mlp_write_norm_scalar_deriver,
    make_mlp_write_to_final_activation_residual_grad_scalar_deriver,
    make_mlp_write_to_final_residual_grad_scalar_deriver,
    make_resid_delta_mlp_from_mlp_post_act_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.raw_activations import (
    make_scalar_deriver_factory_for_act_times_grad,
    make_scalar_deriver_factory_for_activation_location_type,
    make_truncate_to_expected_shape_scalar_deriver_factory_for_dst,
)
from neuron_explainer.activations.derived_scalars.residual import (
    make_previous_layer_resid_post_mlp_scalar_deriver,
    make_residual_norm_scalar_deriver_factory_for_activation_location_type,
    make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type,
    make_token_attribution_scalar_deriver,
    make_unity_scalar_deriver,
    make_vocab_token_write_to_input_direction_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.scalar_deriver import ScalarDeriver
from neuron_explainer.models.model_component_registry import ActivationLocationType, NodeType

### REGISTRY; ADD NEW TYPES HERE, AND ALSO IN THE DerivedScalarType ENUM IN derived_scalar_types.py ###
# Maps each implemented DerivedScalarType to a function that takes a DstConfig and returns the
# corresponding ScalarDeriver. Looked up by make_scalar_deriver below.
_DERIVED_SCALAR_TYPE_REGISTRY: dict[DerivedScalarType, Callable[[DstConfig], ScalarDeriver]] = {
    DerivedScalarType.RESID_POST_EMBEDDING: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_EMBEDDING
    ),
    DerivedScalarType.RESID_POST_MLP: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_MLP
    ),
    DerivedScalarType.RESID_POST_ATTN: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_ATTN
    ),
    DerivedScalarType.RESID_FINAL_LAYER_NORM_SCALE: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_FINAL_LAYER_NORM_SCALE
    ),
    DerivedScalarType.ATTN_INPUT_LAYER_NORM_SCALE: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_INPUT_LAYER_NORM_SCALE
    ),
    DerivedScalarType.MLP_INPUT_LAYER_NORM_SCALE: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.MLP_INPUT_LAYER_NORM_SCALE
    ),
    DerivedScalarType.LOGITS: make_truncate_to_expected_shape_scalar_deriver_factory_for_dst(
        DerivedScalarType.LOGITS
    ),
    DerivedScalarType.MLP_PRE_ACT: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.MLP_PRE_ACT
    ),
    DerivedScalarType.MLP_POST_ACT: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.MLP_POST_ACT
    ),
    DerivedScalarType.ATTN_QUERY: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_QUERY
    ),
    DerivedScalarType.ATTN_KEY: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_KEY
    ),
    DerivedScalarType.ATTN_VALUE: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_VALUE
    ),
    DerivedScalarType.ATTN_QK_LOGITS: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_QK_LOGITS
    ),
    DerivedScalarType.ATTN_QK_PROBS: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_QK_PROBS
    ),
    DerivedScalarType.ATTN_WEIGHTED_SUM_OF_VALUES: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.ATTN_WEIGHTED_SUM_OF_VALUES
    ),
    DerivedScalarType.RESID_DELTA_ATTN: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_ATTN
    ),
    DerivedScalarType.ATTN_WRITE_NORM: make_attn_write_norm_scalar_deriver,
    DerivedScalarType.FLATTENED_ATTN_POST_SOFTMAX: make_flattened_attn_post_softmax_scalar_deriver,
    DerivedScalarType.ATTN_ACT_TIMES_GRAD: make_flattened_attn_post_softmax_act_times_grad_scalar_deriver,
    DerivedScalarType.RESID_DELTA_MLP: make_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_MLP,
    ),
    DerivedScalarType.RESID_DELTA_MLP_FROM_MLP_POST_ACT: make_resid_delta_mlp_from_mlp_post_act_scalar_deriver,
    DerivedScalarType.MLP_WRITE_NORM: make_mlp_write_norm_scalar_deriver,
    DerivedScalarType.MLP_ACT_TIMES_GRAD: make_scalar_deriver_factory_for_act_times_grad(
        ActivationLocationType.MLP_POST_ACT,
        DerivedScalarType.MLP_ACT_TIMES_GRAD,
    ),
    DerivedScalarType.MLP_WRITE_TO_FINAL_RESIDUAL_GRAD: make_mlp_write_to_final_residual_grad_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_NORM_PER_SEQUENCE_TOKEN: make_attn_write_norm_per_sequence_token_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD_PER_SEQUENCE_TOKEN: make_attn_write_to_final_residual_grad_per_sequence_token_scalar_deriver,
    DerivedScalarType.ATTN_ACT_TIMES_GRAD_PER_SEQUENCE_TOKEN: make_attn_act_times_grad_per_sequence_token_scalar_deriver,
    DerivedScalarType.RESID_POST_EMBEDDING_NORM: make_residual_norm_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_EMBEDDING
    ),
    DerivedScalarType.RESID_POST_MLP_NORM: make_residual_norm_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_MLP
    ),
    DerivedScalarType.RESID_POST_ATTN_NORM: make_residual_norm_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_ATTN
    ),
    DerivedScalarType.MLP_LAYER_WRITE_NORM: make_residual_norm_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_MLP
    ),
    DerivedScalarType.ATTN_LAYER_WRITE_NORM: make_residual_norm_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_ATTN
    ),
    DerivedScalarType.RESID_POST_EMBEDDING_PROJ_TO_FINAL_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_EMBEDDING,
        use_existing_backward_pass_for_final_residual_grad=True,
    ),
    DerivedScalarType.RESID_POST_MLP_PROJ_TO_FINAL_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_MLP,
        use_existing_backward_pass_for_final_residual_grad=True,
    ),
    DerivedScalarType.RESID_POST_ATTN_PROJ_TO_FINAL_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_ATTN,
        use_existing_backward_pass_for_final_residual_grad=True,
    ),
    DerivedScalarType.MLP_LAYER_WRITE_TO_FINAL_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_MLP,
        use_existing_backward_pass_for_final_residual_grad=True,
    ),
    DerivedScalarType.ATTN_LAYER_WRITE_TO_FINAL_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_DELTA_ATTN,
        use_existing_backward_pass_for_final_residual_grad=True,
    ),
    DerivedScalarType.UNFLATTENED_ATTN_ACT_TIMES_GRAD: make_scalar_deriver_factory_for_act_times_grad(
        ActivationLocationType.ATTN_QK_PROBS,
        DerivedScalarType.UNFLATTENED_ATTN_ACT_TIMES_GRAD,
    ),
    DerivedScalarType.UNFLATTENED_ATTN_WRITE_NORM: make_unflattened_attn_write_norm_scalar_deriver,
    DerivedScalarType.UNFLATTENED_ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD: make_unflattened_attn_write_to_final_residual_grad_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD: make_flattened_attn_write_to_final_residual_grad_scalar_deriver,
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR: make_online_autoencoder_error_scalar_deriver_factory(
        ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR
    ),
    DerivedScalarType.ONLINE_RESIDUAL_MLP_AUTOENCODER_ERROR: make_online_autoencoder_error_scalar_deriver_factory(
        ActivationLocationType.ONLINE_RESIDUAL_MLP_AUTOENCODER_ERROR
    ),
    DerivedScalarType.ONLINE_RESIDUAL_ATTENTION_AUTOENCODER_ERROR: make_online_autoencoder_error_scalar_deriver_factory(
        ActivationLocationType.ONLINE_RESIDUAL_ATTENTION_AUTOENCODER_ERROR
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_ACT_TIMES_GRAD: make_online_mlp_autoencoder_error_act_times_grad_scalar_deriver,
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_WRITE_NORM: make_online_mlp_autoencoder_error_write_norm_scalar_deriver,
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_WRITE_TO_FINAL_RESIDUAL_GRAD: make_online_mlp_autoencoder_error_write_to_final_residual_grad_scalar_deriver,
    DerivedScalarType.ATTN_WRITE: make_attn_write_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_SUM_HEADS: make_attn_write_sum_heads_scalar_deriver,
    DerivedScalarType.MLP_WRITE: make_mlp_neuronwise_write_scalar_deriver,
    DerivedScalarType.ATTN_WEIGHTED_VALUE: make_attn_weighted_value_scalar_deriver,
    DerivedScalarType.PREVIOUS_LAYER_RESID_POST_MLP: make_previous_layer_resid_post_mlp_scalar_deriver,
    DerivedScalarType.MLP_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_mlp_write_to_final_activation_residual_grad_scalar_deriver,
    DerivedScalarType.UNFLATTENED_ATTN_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_unflattened_attn_write_to_final_activation_residual_grad_scalar_deriver,
    DerivedScalarType.RESID_POST_EMBEDDING_PROJ_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_residual_projection_to_final_residual_grad_scalar_deriver_factory_for_activation_location_type(
        ActivationLocationType.RESID_POST_EMBEDDING,
        use_existing_backward_pass_for_final_residual_grad=False,
    ),
    DerivedScalarType.AUTOENCODER_LATENT_GRAD_WRT_RESIDUAL_INPUT: make_autoencoder_latent_grad_wrt_residual_input_scalar_deriver,
    DerivedScalarType.AUTOENCODER_LATENT_GRAD_WRT_MLP_POST_ACT_INPUT: make_autoencoder_latent_grad_wrt_mlp_post_act_input_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_LATENT: make_attn_write_to_latent_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS: make_attn_write_to_latent_summed_over_heads_scalar_deriver,
    DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS: make_flattened_attn_write_to_latent_summed_over_heads_scalar_deriver,
    DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS_BATCHED: make_flattened_attn_write_to_latent_summed_over_heads_batched_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_LATENT_PER_SEQUENCE_TOKEN: make_attn_write_to_latent_per_sequence_token_scalar_deriver,
    DerivedScalarType.ATTN_WRITE_TO_LATENT_PER_SEQUENCE_TOKEN_BATCHED: make_attn_write_to_latent_per_sequence_token_batched_scalar_deriver,
    DerivedScalarType.TOKEN_ATTRIBUTION: make_token_attribution_scalar_deriver,
    DerivedScalarType.SINGLE_NODE_WRITE: make_node_write_scalar_deriver,
    DerivedScalarType.GRAD_OF_SINGLE_SUBNODE_ATTRIBUTION: make_grad_of_downstream_subnode_attribution_scalar_deriver,
    DerivedScalarType.ATTN_OUT_EDGE_ATTRIBUTION: make_attn_out_edge_attribution_scalar_deriver,
    DerivedScalarType.MLP_OUT_EDGE_ATTRIBUTION: make_mlp_out_edge_attribution_scalar_deriver,
    DerivedScalarType.ONLINE_AUTOENCODER_OUT_EDGE_ATTRIBUTION: make_online_autoencoder_out_edge_attribution_scalar_deriver,
    DerivedScalarType.ATTN_QUERY_IN_EDGE_ATTRIBUTION: make_in_edge_attribution_scalar_deriver_factory(
        NodeType.ATTENTION_HEAD, ActivationLocationType.ATTN_QUERY
    ),
    DerivedScalarType.ATTN_KEY_IN_EDGE_ATTRIBUTION: make_in_edge_attribution_scalar_deriver_factory(
        NodeType.ATTENTION_HEAD, ActivationLocationType.ATTN_KEY
    ),
    DerivedScalarType.ATTN_VALUE_IN_EDGE_ATTRIBUTION: make_in_edge_attribution_scalar_deriver_factory(
        NodeType.ATTENTION_HEAD, ActivationLocationType.ATTN_VALUE
    ),
    DerivedScalarType.MLP_IN_EDGE_ATTRIBUTION: make_in_edge_attribution_scalar_deriver_factory(
        NodeType.MLP_NEURON
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_IN_EDGE_ATTRIBUTION: make_in_edge_attribution_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.SINGLE_NODE_WRITE_TO_FINAL_RESIDUAL_GRAD: make_node_write_to_final_residual_grad_scalar_deriver,
    DerivedScalarType.TOKEN_OUT_EDGE_ATTRIBUTION: make_token_out_edge_attribution_scalar_deriver,
    DerivedScalarType.VOCAB_TOKEN_WRITE_TO_INPUT_DIRECTION: make_vocab_token_write_to_input_direction_scalar_deriver,
    DerivedScalarType.ALWAYS_ONE: make_unity_scalar_deriver,
    DerivedScalarType.ATTN_QUERY_IN_EDGE_ACTIVATION: make_in_edge_activation_scalar_deriver_factory(
        ActivationLocationType.ATTN_QK_PROBS, ActivationLocationType.ATTN_QUERY
    ),
    DerivedScalarType.ATTN_KEY_IN_EDGE_ACTIVATION: make_in_edge_activation_scalar_deriver_factory(
        ActivationLocationType.ATTN_QK_PROBS, ActivationLocationType.ATTN_KEY
    ),
    DerivedScalarType.MLP_IN_EDGE_ACTIVATION: make_in_edge_activation_scalar_deriver_factory(
        ActivationLocationType.MLP_POST_ACT,
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_IN_EDGE_ACTIVATION: make_in_edge_activation_scalar_deriver_factory(
        ActivationLocationType.ONLINE_AUTOENCODER_LATENT,
    ),
    DerivedScalarType.AUTOENCODER_LATENT: make_autoencoder_latent_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.MLP_AUTOENCODER_LATENT: make_autoencoder_latent_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ATTENTION_AUTOENCODER_LATENT: make_autoencoder_latent_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_LATENT: make_online_autoencoder_latent_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT: make_online_autoencoder_latent_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT: make_online_autoencoder_latent_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD: make_online_autoencoder_act_times_grad_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_ACT_TIMES_GRAD: make_online_autoencoder_act_times_grad_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_ACT_TIMES_GRAD: make_online_autoencoder_act_times_grad_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_residual_grad_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_residual_grad_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_residual_grad_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_WRITE: make_online_autoencoder_latentwise_write_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE: make_online_autoencoder_latentwise_write_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE: make_online_autoencoder_latentwise_write_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_AUTOENCODER_WRITE_NORM: make_online_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_NORM: make_online_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_NORM: make_online_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
    DerivedScalarType.AUTOENCODER_WRITE_NORM: make_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.AUTOENCODER_LATENT
    ),
    DerivedScalarType.MLP_AUTOENCODER_WRITE_NORM: make_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.MLP_AUTOENCODER_LATENT
    ),
    DerivedScalarType.ATTENTION_AUTOENCODER_WRITE_NORM: make_autoencoder_write_norm_scalar_deriver_factory(
        NodeType.ATTENTION_AUTOENCODER_LATENT
    ),
}


def make_scalar_deriver(
    dst: DerivedScalarType,
    dst_config: DstConfig,
) -> ScalarDeriver:
    """Build the ScalarDeriver for `dst` using its registered factory.

    Looks up the factory for `dst` in `_DERIVED_SCALAR_TYPE_REGISTRY` and calls it with
    `dst_config`. Which `dst_config` fields (e.g. model name, layer indices) must be populated
    depends on the dst. Dsts that correspond directly to an ActivationLocationType are not
    expected to need the model name or layer indices.

    Args:
        dst: The derived scalar type to build a deriver for.
        dst_config: Configuration passed unchanged to the registered factory.

    Returns:
        The ScalarDeriver returned by the registered factory.

    Raises:
        ValueError: If no factory is registered for `dst`.
    """
    try:
        make_scalar_deriver_fn = _DERIVED_SCALAR_TYPE_REGISTRY[dst]
    except KeyError:
        # Explicit raise instead of `assert`, so the check is not stripped under `python -O`.
        raise ValueError(f"Unknown {dst=}") from None
    return make_scalar_deriver_fn(dst_config)


def make_scalar_deriver_for_activation_location_type(
    activation_location_type: ActivationLocationType,
    derive_gradients: bool = False,
) -> ScalarDeriver:
    """Convenience wrapper around `make_scalar_deriver` for a single ActivationLocationType.

    Args:
        activation_location_type: Mapped to its DerivedScalarType via
            `DerivedScalarType.from_activation_location_type`.
        derive_gradients: Forwarded as `DstConfig(derive_gradients=...)`; every other DstConfig
            field is left at its default.

    Returns:
        The ScalarDeriver for the corresponding DerivedScalarType.
    """
    return make_scalar_deriver(
        DerivedScalarType.from_activation_location_type(activation_location_type),
        dst_config=DstConfig(derive_gradients=derive_gradients),
    )