def _check_activation_ndims()

in neuron_explainer/activations/derived_scalars/derived_scalar_store.py [0:0]


    def _check_activation_ndims(self) -> None:
        # ensure that the shapes for the activation tensors are consistent with the shape specs
        # for the derived scalar types
        for (
            dst,
            pass_type,
        ), activations_and_metadata in self.activations_and_metadata_by_dst_and_pass_type.items():
            shape_spec = dst.shape_spec_per_token_sequence
            for activations in activations_and_metadata.activations_by_layer_index.values():
                assert activations.ndim == len(shape_spec), (
                    f"Expected activations to have ndim {len(shape_spec)}, but got {activations.shape=} "
                    f"for {dst=}, {pass_type=}"
                )