in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
    """Convert a NodeIndex identifying a node in the model into the DerivedScalarIndex
    pointing at the residual-stream activations needed to compute a gradient for that node."""
    if node_index.node_type == NodeType.ATTENTION_HEAD:
        # See _get_residual_stream_tensor_indices_for_node for more information.
        # TODO: finish supporting attention heads
        assert isinstance(node_index, AttnSubNodeIndex), (
            node_index.node_type,
            type(node_index),
        )
        # Only query and key sub-nodes are handled so far.
        assert node_index.q_k_or_v in {
            ActivationLocationType.ATTN_QUERY,
            ActivationLocationType.ATTN_KEY,
        }
    # Look up the derived scalar type whose activations the gradient is taken with
    # respect to for this node type.
    dst_for_computing_grad = self._input_dst_by_node_type[node_index.node_type]
    supported_dsts = list(self._input_dst_by_node_type.values())
    assert dst_for_computing_grad in supported_dsts, (
        f"Node type {node_index.node_type} not supported by this DerivedScalarStore; "
        f"supported derived scalar types are {supported_dsts}"
    )
    updated_tensor_indices = _get_residual_stream_tensor_indices_for_node(node_index)
    # Note: derived scalar indices do not have q_k_or_v associated with them, so
    # constructing a plain NodeIndex here drops that field.
    updated_node_index = NodeIndex(
        node_type=dst_for_computing_grad.node_type,
        # The node's own activation index is replaced with residual-stream indices;
        # the entire residual stream will be needed for computing the gradient.
        tensor_indices=updated_tensor_indices,
        layer_index=node_index.layer_index,
        pass_type=node_index.pass_type,
    )
    return DerivedScalarIndex.from_node_index(
        updated_node_index,
        dst_for_computing_grad,
    )