def convert_node_index_to_ds_index()

in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]


    def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
        """Map ``node_index`` to the DerivedScalarIndex used when computing its gradient.

        The derived scalar type (DST) is looked up in ``self._input_dst_by_node_type`` by
        ``node_index.node_type``. The returned index:
        - takes its node type from that DST,
        - takes its tensor indices from ``_get_residual_stream_tensor_indices_for_node``,
        - keeps the original layer index and pass type.

        Args:
            node_index: The node to convert. Attention-head nodes must be ``AttnSubNodeIndex``
                instances whose ``q_k_or_v`` is ``ATTN_QUERY`` or ``ATTN_KEY``.

        Returns:
            ``DerivedScalarIndex.from_node_index`` applied to the updated node index and the
            looked-up DST.

        Raises:
            AssertionError: If the node type has no entry in ``self._input_dst_by_node_type``,
                or if an attention-head node is not an ``AttnSubNodeIndex`` for a query or key.
        """
        if node_index.node_type == NodeType.ATTENTION_HEAD:
            # see _get_residual_stream_tensor_indices_for_node for more information
            # TODO: finish supporting attention heads
            assert isinstance(node_index, AttnSubNodeIndex), (
                node_index.node_type,
                type(node_index),
            )
            assert node_index.q_k_or_v in {
                ActivationLocationType.ATTN_QUERY,
                ActivationLocationType.ATTN_KEY,
            }, f"Unsupported q_k_or_v {node_index.q_k_or_v}; only query and key are supported"

        # Check membership before indexing. A direct dict lookup would raise a bare KeyError,
        # so the informative message below would never be shown.
        assert node_index.node_type in self._input_dst_by_node_type, (
            f"Node type {node_index.node_type} not supported by this DerivedScalarStore; "
            f"supported node types are {list(self._input_dst_by_node_type.keys())}"
        )
        dst_for_computing_grad = self._input_dst_by_node_type[node_index.node_type]

        updated_tensor_indices = _get_residual_stream_tensor_indices_for_node(node_index)
        # note: derived scalar indices do not have q_k_or_v associated to them, so we build a
        # plain NodeIndex here, which drops that field.
        updated_node_index = NodeIndex(
            node_type=dst_for_computing_grad.node_type,
            # Remove the activation index; the entire residual stream will be needed for computing
            # the gradient.
            tensor_indices=updated_tensor_indices,
            layer_index=node_index.layer_index,
            pass_type=node_index.pass_type,
        )
        return DerivedScalarIndex.from_node_index(
            updated_node_index,
            dst_for_computing_grad,
        )