def node_type()

in neuron_explainer/activations/derived_scalars/derived_scalar_types.py [0:0]


    def node_type(self) -> NodeType:
        """
        Return the kind of 'node' indexed by the last dimension of this derived scalar type.

        A node is an object in the network such as an MLP neuron, an attention head or an
        autoencoder latent.

        Autoencoder-latent types are classified by substring matching on ``self.value``:
        "mlp" is tested before "attention", and anything else falls back to the generic
        autoencoder-latent node type. All other types are resolved by looking up the final
        entry of ``self.shape_spec_per_token_sequence`` in ``node_type_by_dimension``.

        Raises:
            NotImplementedError: if the last dimension has no registered node type yet.
        """
        if not self.is_autoencoder_latent:
            # NOTE: the name `last_dimension` is part of the error text via the f-string's
            # `=` specifier, so it must not be renamed.
            last_dimension = self.shape_spec_per_token_sequence[-1]
            if last_dimension not in node_type_by_dimension:
                raise NotImplementedError(f"Unknown node type for {last_dimension=}")
            return node_type_by_dimension[last_dimension]

        # Order matters: a value containing both substrings resolves to the MLP variant.
        if "mlp" in self.value:
            return NodeType.MLP_AUTOENCODER_LATENT
        if "attention" in self.value:
            return NodeType.ATTENTION_AUTOENCODER_LATENT
        return NodeType.AUTOENCODER_LATENT