def convert_node_index_to_ds_index()

in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]


    def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
        """Convert a NodeIndex into a DerivedScalarIndex for this store.

        The derived scalar type is looked up from ``node_index.node_type`` in
        ``self._input_dst_by_node_type``. The node index is then updated so that
        its ``node_type`` is that derived scalar type's ``node_type`` (and, for
        LAYER nodes, its tensor indices are trimmed; see below). The updated node
        index and the derived scalar type are passed to
        ``DerivedScalarIndex.from_node_index``.

        For ``NodeType.LAYER`` node indices, ``tensor_indices`` must have exactly
        two entries with the second equal to 0. That trailing singleton entry is
        removed because it is not present in the converted derived scalar type.

        Args:
            node_index: The node index to convert.

        Returns:
            The DerivedScalarIndex produced by ``DerivedScalarIndex.from_node_index``.

        Raises:
            KeyError: If ``node_index.node_type`` has no entry in
                ``self._input_dst_by_node_type``.
            AssertionError: If a LAYER node index does not have tensor indices of
                the form ``(i, 0)``.
        """
        # Check membership explicitly. Indexing the dict directly would raise a bare
        # KeyError before any explanatory message could be produced, and comparing
        # the looked-up value against the dict's own values can never fail.
        if node_index.node_type not in self._input_dst_by_node_type:
            raise KeyError(
                f"Node type {node_index.node_type} not supported by this DerivedScalarStore; "
                f"supported node types are {list(self._input_dst_by_node_type)}"
            )
        dst_for_write = self._input_dst_by_node_type[node_index.node_type]
        if node_index.node_type == NodeType.LAYER:
            # remove the final, singleton dimension, which is not in the converted derived scalar type
            assert (
                len(node_index.tensor_indices) == 2
            ), f"Expected 2 tensor indices for a LAYER node, got {node_index.tensor_indices}"
            assert (
                node_index.tensor_indices[1] == 0
            ), f"Expected trailing tensor index 0 for a LAYER node, got {node_index.tensor_indices}"
            updated_tensor_indices: tuple[int | None, ...] = node_index.tensor_indices[:-1]
        else:
            updated_tensor_indices = node_index.tensor_indices
        return DerivedScalarIndex.from_node_index(
            node_index.with_updates(
                node_type=dst_for_write.node_type, tensor_indices=updated_tensor_indices
            ),
            dst_for_write,
        )