in neuron_explainer/activations/derived_scalars/indexing.py [0:0]
def ndim(self) -> int:
    """Return the dimensionality reported by compute_indexed_tensor_ndim for this node.

    Each supported node type maps to one reference ActivationLocationType.
    That location type is passed, together with this object's
    ``tensor_indices``, to ``compute_indexed_tensor_ndim``, and the result
    is returned.

    Raises:
        NotImplementedError: if ``self.node_type`` is not one of the node
            types handled below.
    """
    node_type = self.node_type
    # Each comparison is `subject == value`, matching how `match` value
    # patterns are evaluated, so equality semantics are unchanged.
    if node_type == NodeType.ATTENTION_HEAD:
        location_type = ActivationLocationType.ATTN_QK_PROBS
    elif node_type == NodeType.MLP_NEURON:
        location_type = ActivationLocationType.MLP_POST_ACT
    elif node_type == NodeType.AUTOENCODER_LATENT:
        location_type = ActivationLocationType.ONLINE_AUTOENCODER_LATENT
    elif node_type == NodeType.MLP_AUTOENCODER_LATENT:
        location_type = ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT
    elif node_type == NodeType.ATTENTION_AUTOENCODER_LATENT:
        location_type = ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT
    elif node_type == NodeType.RESIDUAL_STREAM_CHANNEL:
        location_type = ActivationLocationType.RESID_POST_MLP
    else:
        raise NotImplementedError(f"Node type {node_type} not supported")
    return compute_indexed_tensor_ndim(
        activation_location_type=location_type,
        tensor_indices=self.tensor_indices,
    )