in neuron_explainer/activations/derived_scalars/postprocessing.py [0:0]
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
    """Convert a NodeIndex into a DerivedScalarIndex for this DerivedScalarStore.

    The derived scalar type (DST) is looked up in ``self._input_dst_by_node_type`` by
    ``node_index.node_type``. The returned index carries that DST, and its node type is
    replaced with the DST's own ``node_type``.

    For ``NodeType.LAYER`` node indices, ``tensor_indices`` must have exactly two entries
    and the second must be 0; that trailing entry is dropped before conversion.

    Args:
        node_index: The node index to convert.

    Returns:
        The corresponding DerivedScalarIndex.

    Raises:
        AssertionError: if ``node_index.node_type`` has no registered DST in this store, or
            if a LAYER node index does not have the expected ``tensor_indices`` layout.
    """
    # Check membership of the node type *before* indexing. Indexing first would surface an
    # unsupported node type as a bare KeyError. A check on the looked-up value afterwards
    # would be vacuously true, because the value was just read from the same dict.
    assert node_index.node_type in self._input_dst_by_node_type, (
        f"Node type {node_index.node_type} not supported by this DerivedScalarStore; "
        f"supported node types are {list(self._input_dst_by_node_type)}"
    )
    dst_for_write = self._input_dst_by_node_type[node_index.node_type]
    if node_index.node_type == NodeType.LAYER:
        # remove the final, singleton dimension, which is not in the converted derived scalar type
        assert len(node_index.tensor_indices) == 2, (
            f"Expected 2 tensor indices for a LAYER node, got {node_index.tensor_indices}"
        )
        assert node_index.tensor_indices[1] == 0, (
            f"Expected final tensor index 0 for a LAYER node, got {node_index.tensor_indices}"
        )
        updated_tensor_indices: tuple[int | None, ...] = node_index.tensor_indices[:-1]
    else:
        updated_tensor_indices = node_index.tensor_indices
    return DerivedScalarIndex.from_node_index(
        node_index.with_updates(
            node_type=dst_for_write.node_type, tensor_indices=updated_tensor_indices
        ),
        dst_for_write,
    )