in neuron_explainer/activations/derived_scalars/derived_scalar_types.py [0:0]
def update_from_autoencoder_node_type(self, node_type: NodeType | None) -> "DerivedScalarType":
    """
    Return the variant of this DST that corresponds to the given autoencoder node type.

    When multiple autoencoders are used, the DST needs to be specific to the autoencoder type:
    - NodeType.AUTOENCODER_LATENT: default autoencoder, used when no specific autoencoder type is specified
    - NodeType.MLP_AUTOENCODER_LATENT: autoencoder trained on activations from an MLP layer
    - NodeType.ATTENTION_AUTOENCODER_LATENT: autoencoder trained on activations from an Attention layer

    A falsy node_type (e.g. None) is replaced by NodeType.AUTOENCODER_LATENT, in which case
    the DST is returned unchanged.

    Raises:
        AssertionError: if node_type.is_autoencoder_latent is falsy (only when assertions are enabled).
        KeyError: if this DST is not the first entry of any row in the table below, or if
            node_type is not one of the three node types listed above.
    """
    resolved_node_type = node_type if node_type else NodeType.AUTOENCODER_LATENT
    assert resolved_node_type.is_autoencoder_latent

    # Column order shared by every row of the table below.
    node_type_columns = (
        NodeType.AUTOENCODER_LATENT,
        NodeType.MLP_AUTOENCODER_LATENT,
        NodeType.ATTENTION_AUTOENCODER_LATENT,
    )
    # Each row is (default-autoencoder DST, MLP-autoencoder DST, attention-autoencoder DST).
    # The first entry doubles as the lookup key for the row.
    dst_rows = (
        (
            DerivedScalarType.AUTOENCODER_LATENT,
            DerivedScalarType.MLP_AUTOENCODER_LATENT,
            DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_LATENT,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_ACT_TIMES_GRAD,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE,
        ),
        (
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_NORM,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_NORM,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_NORM,
        ),
        (
            DerivedScalarType.AUTOENCODER_WRITE_NORM,
            DerivedScalarType.MLP_AUTOENCODER_WRITE_NORM,
            DerivedScalarType.ATTENTION_AUTOENCODER_WRITE_NORM,
        ),
    )
    variants_by_default_dst = {
        row[0]: dict(zip(node_type_columns, row)) for row in dst_rows
    }

    # Look up this DST first, then the node type, so a failed lookup reports the same key
    # in the same order as a nested-dict lookup would.
    variants_for_self = variants_by_default_dst[self]
    return variants_for_self[resolved_node_type]