def tensor_index_by_dim()

in neuron_explainer/activations/derived_scalars/indexing.py [0:0]


    def tensor_index_by_dim(self) -> dict[Dimension, AllOrOneIndex]:
        """Return a mapping from each per-token-sequence dimension to its tensor index.

        The dimensions are taken, in order, from
        ``self.activation_location_type.shape_spec_per_token_sequence``. The first
        ``len(self.tensor_indices)`` dimensions are paired with the explicitly given
        indices. Every remaining trailing dimension is paired with the string ``"All"``.

        Raises:
            AssertionError: if more tensor indices are given than there are dimensions
                in the shape spec.
        """
        # NOTE: duplicated from DerivedScalarIndex. TODO: give ActivationIndex and
        # DerivedScalarIndex a common base class (and perhaps do the same for
        # DerivedScalarType and ActivationLocationType).
        explicit_indices = list(self.tensor_indices)
        dims = self.activation_location_type.shape_spec_per_token_sequence
        num_dims = len(dims)

        assert len(explicit_indices) <= num_dims, (
            f"Too many tensor indices {explicit_indices} for "
            f"{self.activation_location_type.shape_spec_per_token_sequence=}"
        )

        # Any dimension without an explicit index is fully selected ("All").
        num_missing = num_dims - len(self.tensor_indices)
        padded_indices = explicit_indices + ["All"] * num_missing

        assert len(padded_indices) == num_dims

        return {dim: index for dim, index in zip(dims, padded_indices)}