def _get_residual_stream_tensor_indices_for_node()

in neuron_explainer/activations/derived_scalars/postprocessing.py


def _get_residual_stream_tensor_indices_for_node(node_index: NodeIndex) -> tuple[int]:
    """For a given node index defining a point from which the gradient will be computed, this identifies the token
    indices at which the gradient immediately before the node will be nonzero. For attention, in order for there to
    be exactly one such token index, the gradient is computed through one of query/key/value, with a stopgrad
    through the others. Depending on which of query/key/value is used, the token index will be either the query token
    index or the key/value token index. For MLP neurons and autoencoder latents, the token index will be the token
    index of the node itself.
    """
    # the returned tensor indices are expected to be a tuple[int, ...], even if of length 1
    match node_index.node_type:
        case NodeType.ATTENTION_HEAD:
            # in the case of attention head reads, there are several possible ways to interpret the "read" direction
            # - the gradient through just the query (at the query token)
            # - the gradient through just the key (at the key/value token)
            # - the gradient with respect to some function of the attention write, e.g. the attention write norm,
            # through just the value (at the key/value token)
            assert isinstance(node_index, AttnSubNodeIndex)
            assert len(node_index.tensor_indices) == 3
            match node_index.q_k_or_v:
                case ActivationLocationType.ATTN_QUERY:
                    tensor_index = node_index.tensor_indices[1]  # just the query token index
                case ActivationLocationType.ATTN_KEY | ActivationLocationType.ATTN_VALUE:
                    tensor_index = node_index.tensor_indices[0]  # just the key/value token index
                case _:
                    raise ValueError(f"Unexpected q_k_or_v: {node_index.q_k_or_v}")
        case (
            NodeType.MLP_NEURON
            | NodeType.AUTOENCODER_LATENT
            | NodeType.MLP_AUTOENCODER_LATENT
            | NodeType.ATTENTION_AUTOENCODER_LATENT
        ):
            assert len(node_index.tensor_indices) == 2
            tensor_index = node_index.tensor_indices[0]  # just the token index
        case _:
            raise ValueError(f"Node type {node_index.node_type} not supported")
    assert isinstance(tensor_index, int), (tensor_index, type(tensor_index))
    return (tensor_index,)
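
Usage sketch (illustrative only): the real NodeIndex, AttnSubNodeIndex, NodeType, and ActivationLocationType classes are defined elsewhere in neuron_explainer, so the minimal stand-ins below are assumptions, used purely to show how a node's tensor_indices map to the single residual-stream token index this function returns.

from dataclasses import dataclass
from enum import Enum, auto

class NodeType(Enum):
    ATTENTION_HEAD = auto()
    MLP_NEURON = auto()

class ActivationLocationType(Enum):
    ATTN_QUERY = auto()
    ATTN_KEY = auto()
    ATTN_VALUE = auto()

@dataclass
class NodeIndex:
    node_type: NodeType
    tensor_indices: tuple[int, ...]

@dataclass
class AttnSubNodeIndex(NodeIndex):
    q_k_or_v: ActivationLocationType

# Attention head node: tensor_indices has three entries; index 0 is the key/value
# token index and index 1 is the query token index (the third entry is unused here).
attn_node = AttnSubNodeIndex(
    node_type=NodeType.ATTENTION_HEAD,
    tensor_indices=(3, 7, 0),
    q_k_or_v=ActivationLocationType.ATTN_QUERY,
)
# _get_residual_stream_tensor_indices_for_node(attn_node) -> (7,), the query token;
# with q_k_or_v set to ATTN_KEY or ATTN_VALUE it would instead return (3,).

# MLP neuron node: tensor_indices has two entries; index 0 is the token index
# (the second entry, e.g. the neuron index, is unused here).
mlp_node = NodeIndex(node_type=NodeType.MLP_NEURON, tensor_indices=(5, 123))
# _get_residual_stream_tensor_indices_for_node(mlp_node) -> (5,)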