neuron_explainer/activation_server/tdb_conversions.py (279 lines of code) (raw):

# Code for converting between client-friendly TDB request/response dataclasses and internal # representations used during request processing. from typing import TypeVar from neuron_explainer.activation_server.requests_and_responses import * from neuron_explainer.activation_server.requests_and_responses import GroupId from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.activations.derived_scalars.indexing import ( DETACH_LAYER_NORM_SCALE, AblationSpec, AttentionTraceType, MirroredActivationIndex, MirroredNodeIndex, MirroredTraceConfig, NodeAblation, NodeToTrace, PreOrPostAct, TraceConfig, ) from neuron_explainer.models.model_component_registry import ( ActivationLocationType, Dimension, NodeType, PassType, ) T = TypeVar("T") def convert_tdb_request_spec_to_inference_sub_request( tdb_request_spec: TdbRequestSpec, ) -> InferenceSubRequest: """ The client sends a TdbRequestSpec, but internally we do all processing in terms of InferenceSubRequests. This function converts from the client representation to the server representation. """ loss_fn_config: LossFnConfig | None = LossFnConfig( name=LossFnName.LOGIT_DIFF, target_tokens=tdb_request_spec.target_tokens, distractor_tokens=tdb_request_spec.distractor_tokens, ) ablation_specs = [ node_ablation_to_ablation_spec(ablation) for ablation in (tdb_request_spec.node_ablations or []) ] + [ AblationSpec( index=MirroredActivationIndex( activation_location_type=ActivationLocationType.RESID_FINAL_LAYER_NORM_SCALE, pass_type=PassType.BACKWARD, tensor_indices=("All", "All"), # ablate at all positions in the sequence layer_index=None, ), value=0, ) ] current_token_index = -1 trace_config = None if tdb_request_spec.upstream_node_to_trace is None: assert tdb_request_spec.downstream_node_to_trace is None else: ( trace_config, trace_token_index, ) = nodes_to_trace_to_trace_config( tdb_request_spec.upstream_node_to_trace, tdb_request_spec.downstream_node_to_trace ) if trace_token_index is not None: current_token_index = trace_token_index if trace_config is None: # not tracing -> DO compute loss pass elif trace_config.attention_trace_type == AttentionTraceType.V: # tracing attention through V if trace_config.downstream_trace_config is None: pass # tracing through V with no downstream trace -> DO compute loss else: loss_fn_config = ( None # tracing through V, but also with downstream trace -> DON'T compute loss ) else: loss_fn_config = ( None # tracing something other than attention through V -> DON'T compute loss ) inference_request_spec = InferenceRequestSpec( prompt=tdb_request_spec.prompt, loss_fn_config=loss_fn_config, ablation_specs=ablation_specs, trace_config=MirroredTraceConfig.from_trace_config(trace_config) if trace_config else None, ) spec_by_component_for_top_k = MultipleTopKDerivedScalarsRequestSpec( token_index=None, dst_list_by_group_id=make_grouped_dsts_per_component( tdb_request_spec.component_type_for_mlp, tdb_request_spec.component_type_for_attention, ), top_and_bottom_k=tdb_request_spec.top_and_bottom_k_for_node_table, ) spec_by_component_always_mlp_for_token_display = MultipleTopKDerivedScalarsRequestSpec( token_index=None, # the response to this request is to be used for summarizing the effects of entire # attention and MLP layers per token; thus, using a different basis for the activations # within a layer is not helpful, and we use MLP activations themselves. dst_list_by_group_id=make_grouped_dsts_per_component( ComponentTypeForMlp.NEURON, ComponentTypeForAttention.ATTENTION_HEAD ), top_and_bottom_k=1, dimensions_to_keep_for_intermediate_sum=[ Dimension.SEQUENCE_TOKENS, Dimension.ATTENDED_TO_SEQUENCE_TOKENS, ], ) def scored_tokens_request_spec( token_scoring_type: TokenScoringType, ) -> ScoredTokensRequestSpec: return ScoredTokensRequestSpec( token_scoring_type=token_scoring_type, num_tokens=10, # Our scored tokens requests are associated with the "topKComponents" request spec. This # means that they use the same node indices, DSTs and DST configs. depends_on_spec_name="topKComponents", ) def token_pair_attribution_request_spec() -> TokenPairAttributionRequestSpec: return TokenPairAttributionRequestSpec( num_tokens_attended_to=3, # Our scored tokens requests are associated with the "topKComponents" request spec. This # means that they use the same node indices, DSTs and DST configs. depends_on_spec_name="topKComponents", ) processing_request_spec_by_name: dict[str, ProcessingRequestSpec] = { "topKComponents": spec_by_component_for_top_k, "componentSumsForTokenDisplay": spec_by_component_always_mlp_for_token_display, # It's important for these request specs to come after the "topKComponents" request spec, # since they depend on data generated for that request spec. "upvotedOutputTokens": scored_tokens_request_spec(TokenScoringType.UPVOTED_OUTPUT_TOKENS), "inputTokensThatUpvoteMlp": scored_tokens_request_spec( TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_MLP ), "inputTokensThatUpvoteAttnQ": scored_tokens_request_spec( TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_Q ), "inputTokensThatUpvoteAttnK": scored_tokens_request_spec( TokenScoringType.INPUT_TOKENS_THAT_UPVOTE_ATTN_K ), "tokenPairAttribution": token_pair_attribution_request_spec(), } spec_by_vocab_token = MultipleTopKDerivedScalarsRequestSpec( dst_list_by_group_id={GroupId.LOGITS: [DerivedScalarType.LOGITS]}, top_and_bottom_k=100, token_index=current_token_index, ) processing_request_spec_by_name["topOutputTokenLogits"] = spec_by_vocab_token return InferenceSubRequest( inference_request_spec=inference_request_spec, processing_request_spec_by_name=processing_request_spec_by_name, ) def node_ablation_to_ablation_spec(node_ablation: NodeAblation) -> AblationSpec: node_index = node_ablation.node_index value = node_ablation.value match node_index.node_type: case NodeType.ATTENTION_HEAD: activation_location_type = ActivationLocationType.ATTN_QK_PROBS indices = [ get_sequence_token_index(node_index), "All", get_activation_index(node_index), ] case NodeType.MLP_NEURON: activation_location_type = ActivationLocationType.MLP_POST_ACT indices = [ get_sequence_token_index(node_index), get_activation_index(node_index), ] case ( NodeType.AUTOENCODER_LATENT | NodeType.MLP_AUTOENCODER_LATENT | NodeType.ATTENTION_AUTOENCODER_LATENT ): from neuron_explainer.activations.derived_scalars.autoencoder import ( get_autoencoder_alt_from_node_type, ) activation_location_type = get_autoencoder_alt_from_node_type(node_index.node_type) indices = [ get_sequence_token_index(node_index), get_activation_index(node_index), ] case _: raise ValueError(f"Unknown node type {node_index.node_type}") return AblationSpec( index=MirroredActivationIndex( activation_location_type=activation_location_type, pass_type=PassType.FORWARD, # mypy has trouble understanding that all of the values that can be assigned to indices # match AllOrOneIndices. tensor_indices=indices, # type: ignore layer_index=node_index.layer_index, ), value=value, ) def get_sequence_token_index(node_index: MirroredNodeIndex) -> int: return assert_non_none(node_index.tensor_indices[0]) def get_activation_index(node_index: MirroredNodeIndex) -> int: return assert_non_none(node_index.tensor_indices[-1]) def assert_non_none(value: T | None) -> T: assert value is not None return value def make_grouped_dsts_per_component( component_type_for_mlp: ComponentTypeForMlp, component_type_for_attention: ComponentTypeForAttention, ) -> dict[GroupId, list[DerivedScalarType]]: # common dsts for all components dsts = { GroupId.WRITE_NORM: [ DerivedScalarType.RESID_POST_EMBEDDING_NORM, ], GroupId.ACT_TIMES_GRAD: [ DerivedScalarType.TOKEN_ATTRIBUTION, ], GroupId.DIRECTION_WRITE: [ DerivedScalarType.RESID_POST_EMBEDDING_PROJ_TO_FINAL_RESIDUAL_GRAD, ], GroupId.ACTIVATION: [ DerivedScalarType.ALWAYS_ONE, # the resid post embedding is considered to have an # "activation" of 1.0 at every position, for display purposes ], } match component_type_for_mlp: case ComponentTypeForMlp.NEURON: dsts[GroupId.WRITE_NORM].append(DerivedScalarType.MLP_WRITE_NORM) dsts[GroupId.ACT_TIMES_GRAD].append(DerivedScalarType.MLP_ACT_TIMES_GRAD) dsts[GroupId.DIRECTION_WRITE].append(DerivedScalarType.MLP_WRITE_TO_FINAL_RESIDUAL_GRAD) dsts[GroupId.ACTIVATION].append(DerivedScalarType.MLP_POST_ACT) case ComponentTypeForMlp.AUTOENCODER_LATENT: dsts[GroupId.WRITE_NORM].append(DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_NORM) dsts[GroupId.ACT_TIMES_GRAD].append( DerivedScalarType.ONLINE_MLP_AUTOENCODER_ACT_TIMES_GRAD ) dsts[GroupId.DIRECTION_WRITE].append( DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD ) dsts[GroupId.ACTIVATION].append(DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT) case _: raise ValueError(f"Unknown component type {component_type_for_mlp} in TdbRequestSpec") match component_type_for_attention: case ComponentTypeForAttention.ATTENTION_HEAD: dsts[GroupId.WRITE_NORM].append(DerivedScalarType.UNFLATTENED_ATTN_WRITE_NORM) dsts[GroupId.ACT_TIMES_GRAD].append(DerivedScalarType.UNFLATTENED_ATTN_ACT_TIMES_GRAD) dsts[GroupId.DIRECTION_WRITE].append( DerivedScalarType.UNFLATTENED_ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD ) dsts[GroupId.ACTIVATION].append(DerivedScalarType.ATTN_QK_PROBS) case ComponentTypeForAttention.AUTOENCODER_LATENT: dsts[GroupId.WRITE_NORM].append( DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_NORM ) dsts[GroupId.ACT_TIMES_GRAD].append( DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_ACT_TIMES_GRAD ) dsts[GroupId.DIRECTION_WRITE].append( DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD ) dsts[GroupId.ACTIVATION].append(DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT) case _: raise ValueError( f"Unknown component type {component_type_for_attention} in TdbRequestSpec" ) return dsts def nodes_to_trace_to_trace_config( upstream_node_to_trace: NodeToTrace, downstream_node_to_trace: NodeToTrace | None, ) -> tuple[TraceConfig, int | None]: node_index = upstream_node_to_trace.node_index attention_trace_type = upstream_node_to_trace.attention_trace_type if downstream_node_to_trace is None: downstream_trace_config = None else: # only trace through V admits a downstream node to trace assert attention_trace_type == AttentionTraceType.V # don't assign downstream trace_token_index to a variable, as it's not used downstream_trace_config, _ = nodes_to_trace_to_trace_config( downstream_node_to_trace, None, # treat the downstream node as the "upstream" node to trace ) # NOTE: downstream node must not trace through V trace_token_index = node_index.tensor_indices[0] return ( TraceConfig( node_index=node_index.to_node_index(), pre_or_post_act=PreOrPostAct.PRE, attention_trace_type=attention_trace_type, downstream_trace_config=downstream_trace_config, detach_layer_norm_scale=DETACH_LAYER_NORM_SCALE, ), trace_token_index, ) def named_attention_head_indices(node_index: MirroredNodeIndex) -> tuple[int, int, int]: if node_index.node_type != NodeType.ATTENTION_HEAD: raise ValueError("Incorrect nodeType for namedAttentionHeadIndices function") ( attended_from_token_index, attended_to_token_index, attention_head_index, ) = node_index.tensor_indices return ( assert_non_none(attended_from_token_index), assert_non_none(attended_to_token_index), assert_non_none(attention_head_index), )