def __getitem__()

in neuron_explainer/activations/derived_scalars/derived_scalar_store.py

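Retrieves the activations addressed by a single DerivedScalarIndex, or stacks the per-key results into one tensor when given a list of indices. Each key names a derived scalar type (dst), a pass type, a layer index, and optional tensor_indices, where a None entry selects an entire dimension.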

    def __getitem__(self, key: DerivedScalarIndex | list[DerivedScalarIndex]) -> torch.Tensor:
        # A list of keys is resolved recursively: each key is looked up on its own and the
        # per-key results are stacked into a single tensor.
        if isinstance(key, list):
            items = [self.__getitem__(k) for k in key]
            if all(isinstance(item, torch.Tensor) for item in items):
                return torch.stack(items)
            else:
                assert all(
                    isinstance(item, float) for item in items
                ), f"Expected all items to be torch tensors or floats, but got {items}"
                return torch.tensor(items)
        else:
            # A single key addresses one layer of one (derived scalar type, pass type) pair;
            # the requested layer must be among the layers stored for that pair.
            layer_indices = self.sorted_layer_indices_by_dst_and_pass_type[(key.dst, key.pass_type)]
            layer_index = key.layer_index

            assert (
                layer_index in layer_indices
            ), f"Layer index {layer_index} not in layer_indices {layer_indices}"
            # None entries in key.tensor_indices select an entire dimension (translated to a
            # full slice); integer entries pick out a single position along that dimension.
            indices_for_tensor: tuple[slice | int | None, ...] = tuple(
                slice(None) if index is None else index for index in key.tensor_indices
            )
            # Fetch the stored activations tensor for this layer.
            tensor_for_layer = self.activations_and_metadata_by_dst_and_pass_type[
                (key.dst, key.pass_type)
            ].activations_by_layer_index[layer_index]
            assert key.dst is not None
            assert len(indices_for_tensor) <= tensor_for_layer.ndim, (
                f"Too many indices for tensor of shape {tensor_for_layer.shape} "
                f"and indices {indices_for_tensor}; "
                f"{key.dst=}, {key.pass_type=}, {key.layer_index=}"
            )
            return tensor_for_layer[indices_for_tensor]
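
A minimal usage sketch follows. It assumes store is an already-populated DerivedScalarStore and that DerivedScalarIndex is constructed from the dst, pass_type, layer_index, and tensor_indices fields read above; the specific DerivedScalarType and PassType members shown (and the shape implied by tensor_indices) are illustrative assumptions, not values taken from this file, and import lines are omitted.

    # Hypothetical key: the enum members below are assumptions, not taken from this file.
    key = DerivedScalarIndex(
        dst=DerivedScalarType.MLP_POST_ACT,  # assumed derived scalar type
        pass_type=PassType.FORWARD,          # assumed pass type
        layer_index=3,
        tensor_indices=(5, None),            # token position 5, every activation dimension
    )
    single = store[key]  # single-layer result, sliced according to tensor_indices

    # A list of keys resolves each one recursively and stacks the results
    # along a new leading dimension.
    stacked = store[[key, key]]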