in neuron_explainer/activations/derived_scalars/derived_scalar_store.py [0:0]
def _check_activation_ndims(self) -> None:
    """Assert that every stored activation tensor has the rank its derived scalar type expects.

    Walks each ``(dst, pass_type)`` entry of
    ``self.activations_and_metadata_by_dst_and_pass_type``. For every per-layer tensor in
    that entry's ``activations_by_layer_index``, it checks that ``activations.ndim`` equals
    ``len(dst.shape_spec_per_token_sequence)``.

    Raises:
        AssertionError: if some tensor's number of dimensions differs from the length of its
            derived scalar type's shape spec (only while assertions are enabled, i.e. not
            under ``python -O``).
    """
    store = self.activations_and_metadata_by_dst_and_pass_type
    for dst_and_pass_type, entry in store.items():
        # NOTE: the names `dst`, `pass_type` and `activations` are echoed verbatim in the
        # failure message via the f-string `=` specifier, so they must not be renamed.
        dst, pass_type = dst_and_pass_type
        spec = dst.shape_spec_per_token_sequence
        per_layer = entry.activations_by_layer_index
        for activations in per_layer.values():
            expected_ndim = len(spec)
            assert activations.ndim == expected_ndim, (
                f"Expected activations to have ndim {expected_ndim}, but got {activations.shape=} "
                f"for {dst=}, {pass_type=}"
            )