in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]
def _get_residual_stream_tensor_indices_for_node(node_index: NodeIndex) -> tuple[int]:
"""For a given node index defining a point from which the gradient will be computed, this identifies the token
indices at which the gradient immediately before the node will be nonzero. For attention, in order for there to
be exactly one such token index, the gradient is computed through one of query/key/value, with a stopgrad
through the others. Depending on which of query/key/value is used, the token index will be either the query token
index or the key/value token index. For MLP neurons, the token index will be the token index of the neuron.
"""
# tensor_indices are expected to be tuple[int, ...], even if length 1
match node_index.node_type:
case NodeType.ATTENTION_HEAD:
# in the case of attention head reads, there are several possible ways to interpret the "read" direction
# - the gradient through just the query (at the query token)
# - the gradient through just the key (at the key/value token)
# - the gradient with respect to some function of the attention write, e.g. the attention write norm,
# through just the value (at the key/value token)
assert isinstance(node_index, AttnSubNodeIndex)
assert len(node_index.tensor_indices) == 3
match node_index.q_k_or_v:
case ActivationLocationType.ATTN_QUERY:
tensor_index = node_index.tensor_indices[1] # just the query token index
case ActivationLocationType.ATTN_KEY | ActivationLocationType.ATTN_VALUE:
tensor_index = node_index.tensor_indices[0] # just the key/value token index
case _:
raise ValueError(f"Unexpected q_k_or_v: {node_index.q_k_or_v}")
case (
NodeType.MLP_NEURON
| NodeType.AUTOENCODER_LATENT
| NodeType.MLP_AUTOENCODER_LATENT
| NodeType.ATTENTION_AUTOENCODER_LATENT
):
assert len(node_index.tensor_indices) == 2
tensor_index = node_index.tensor_indices[0] # just the token index
case _:
raise ValueError(f"Node type {node_index.node_type} not supported")
assert isinstance(tensor_index, int), (tensor_index, type(tensor_index))
return (tensor_index,)