def ndim()

in neuron_explainer/activations/derived_scalars/indexing.py [0:0]


    def ndim(self) -> int:
        """Return the dimensionality of the tensor this node indexes into.

        Each supported ``NodeType`` is paired with a reference
        ``ActivationLocationType``. That location type and this object's
        ``tensor_indices`` are passed to ``compute_indexed_tensor_ndim``, and its
        result is returned unchanged.

        NOTE(review): presumably this is the number of dimensions left after
        applying ``tensor_indices`` to an activation of the reference location
        type — confirm against ``compute_indexed_tensor_ndim``.

        Raises:
            NotImplementedError: if ``self.node_type`` has no reference
                activation location type.
        """
        kind = self.node_type
        # Comparisons use ``==``, matching the value-pattern semantics of a
        # ``match`` statement over enum members.
        if kind == NodeType.ATTENTION_HEAD:
            location_type = ActivationLocationType.ATTN_QK_PROBS
        elif kind == NodeType.MLP_NEURON:
            location_type = ActivationLocationType.MLP_POST_ACT
        elif kind == NodeType.AUTOENCODER_LATENT:
            location_type = ActivationLocationType.ONLINE_AUTOENCODER_LATENT
        elif kind == NodeType.MLP_AUTOENCODER_LATENT:
            location_type = ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT
        elif kind == NodeType.ATTENTION_AUTOENCODER_LATENT:
            location_type = ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT
        elif kind == NodeType.RESIDUAL_STREAM_CHANNEL:
            location_type = ActivationLocationType.RESID_POST_MLP
        else:
            raise NotImplementedError(f"Node type {self.node_type} not supported")

        return compute_indexed_tensor_ndim(
            activation_location_type=location_type,
            tensor_indices=self.tensor_indices,
        )