in neuron_explainer/activations/derived_scalars/derived_scalar_store.py [0:0]
def __getitem__(self, key: DerivedScalarIndex | list[DerivedScalarIndex]) -> torch.Tensor:
# indexed by index within layer_indices
if isinstance(key, list):
items = [self.__getitem__(k) for k in key]
if all(isinstance(item, torch.Tensor) for item in items):
return torch.stack(items)
else:
assert all(
isinstance(item, float) for item in items
), f"Expected all items to be torch tensors or floats, but got {items}"
return torch.tensor(items)
else:
layer_indices = self.sorted_layer_indices_by_dst_and_pass_type[(key.dst, key.pass_type)]
layer_index = key.layer_index
assert (
layer_index in layer_indices
), f"Layer index {layer_index} not in layer_indices {layer_indices}"
indices_for_tensor: tuple[slice | int | None, ...] = tuple(
slice(None) if index is None else index for index in key.tensor_indices
)
tensor_for_layer = self.activations_and_metadata_by_dst_and_pass_type[
(key.dst, key.pass_type)
].activations_by_layer_index[layer_index]
assert key.dst is not None
assert len(indices_for_tensor) <= tensor_for_layer.ndim, (
f"Too many indices for tensor of shape {tensor_for_layer.shape} "
f"and indices {indices_for_tensor}; "
f"{key.dst=}, {key.pass_type=}, {key.layer_index=}"
)
return tensor_for_layer[indices_for_tensor]