def update_from_autoencoder_node_type()

in neuron_explainer/activations/derived_scalars/derived_scalar_types.py [0:0]


    def update_from_autoencoder_node_type(self, node_type: NodeType | None) -> "DerivedScalarType":
        """
        Return the variant of this DST that matches a particular autoencoder node type.

        When several autoencoders are in use, a DST has to name which autoencoder it refers to.
        Three node types are recognized:
            - NodeType.AUTOENCODER_LATENT: the default autoencoder; also used when node_type is
              None (or otherwise falsy). For this node type, self is returned unchanged.
            - NodeType.MLP_AUTOENCODER_LATENT: an autoencoder trained on MLP-layer activations.
            - NodeType.ATTENTION_AUTOENCODER_LATENT: an autoencoder trained on attention-layer
              activations.

        Raises:
            AssertionError: if node_type is not an autoencoder-latent node type.
            KeyError: if self has no per-autoencoder variants, or if node_type is an
                autoencoder-latent type outside the three listed above.
        """
        if not node_type:
            node_type = NodeType.AUTOENCODER_LATENT
        assert node_type.is_autoencoder_latent

        # Each base DST maps to its (default, MLP, attention) autoencoder variants, in that order.
        variants_by_base_dst = {
            DerivedScalarType.AUTOENCODER_LATENT: (
                DerivedScalarType.AUTOENCODER_LATENT,
                DerivedScalarType.MLP_AUTOENCODER_LATENT,
                DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_LATENT: (
                DerivedScalarType.ONLINE_AUTOENCODER_LATENT,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD: (
                DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_ACT_TIMES_GRAD,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_ACT_TIMES_GRAD,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD: (
                DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD: (
                DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE: (
                DerivedScalarType.ONLINE_AUTOENCODER_WRITE,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE,
            ),
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_NORM: (
                DerivedScalarType.ONLINE_AUTOENCODER_WRITE_NORM,
                DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_NORM,
                DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_NORM,
            ),
            DerivedScalarType.AUTOENCODER_WRITE_NORM: (
                DerivedScalarType.AUTOENCODER_WRITE_NORM,
                DerivedScalarType.MLP_AUTOENCODER_WRITE_NORM,
                DerivedScalarType.ATTENTION_AUTOENCODER_WRITE_NORM,
            ),
        }

        # Indexing by self first means an unsupported DST fails with KeyError(self).
        default_dst, mlp_dst, attention_dst = variants_by_base_dst[self]
        dst_for_node_type = {
            NodeType.AUTOENCODER_LATENT: default_dst,
            NodeType.MLP_AUTOENCODER_LATENT: mlp_dst,
            NodeType.ATTENTION_AUTOENCODER_LATENT: attention_dst,
        }
        return dst_for_node_type[node_type]