in neuron_explainer/activations/derived_scalars/derived_scalar_types.py [0:0]
def node_type(self) -> NodeType:
    """
    Return the kind of 'node' indexed by the last dimension of this derived scalar type.

    A node is an object in the network such as an MLP neuron, an attention head or an
    autoencoder latent. Autoencoder latents are classified by substrings of `self.value`
    ("mlp", then "attention", otherwise the generic latent type). Every other type is
    resolved by looking up the last entry of `self.shape_spec_per_token_sequence` in
    `node_type_by_dimension`.

    Raises:
        NotImplementedError: if the last dimension has no entry in `node_type_by_dimension`.
    """
    if not self.is_autoencoder_latent:
        # The local name is part of the error text below (via the `=` f-string specifier),
        # so it must stay `last_dimension`.
        last_dimension = self.shape_spec_per_token_sequence[-1]
        if last_dimension not in node_type_by_dimension:
            raise NotImplementedError(f"Unknown node type for {last_dimension=}")
        return node_type_by_dimension[last_dimension]

    # Autoencoder latents: "mlp" is checked before "attention", so a value containing both
    # substrings resolves to the MLP variant.
    type_name = self.value
    if "mlp" in type_name:
        return NodeType.MLP_AUTOENCODER_LATENT
    if "attention" in type_name:
        return NodeType.ATTENTION_AUTOENCODER_LATENT
    return NodeType.AUTOENCODER_LATENT