def requires_grad_for_forward_pass()

in neuron_explainer/activations/derived_scalars/derived_scalar_types.py [0:0]


    def requires_grad_for_forward_pass(self) -> bool:
        """Return whether this type is one of the members enumerated in this method.

        The result is True exactly when ``self`` belongs to the union of the
        three groups below, and False otherwise. The grouping is purely by
        member name and has no effect on the result.

        NOTE(review): going by the method name, these are presumably the types
        whose forward pass must be run with gradients enabled -- confirm
        against the callers.
        """
        # Members whose names contain ACT_TIMES_GRAD.
        act_times_grad_types = {
            DerivedScalarType.MLP_ACT_TIMES_GRAD,
            DerivedScalarType.ATTN_ACT_TIMES_GRAD,
            DerivedScalarType.ATTN_ACT_TIMES_GRAD_PER_SEQUENCE_TOKEN,
            DerivedScalarType.UNFLATTENED_ATTN_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_ACT_TIMES_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_ACT_TIMES_GRAD,
        }

        # Members whose names contain TO_FINAL_RESIDUAL_GRAD (writes and projections).
        final_residual_grad_types = {
            DerivedScalarType.MLP_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD_PER_SEQUENCE_TOKEN,
            DerivedScalarType.UNFLATTENED_ATTN_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.RESID_POST_EMBEDDING_PROJ_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.RESID_POST_MLP_PROJ_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.MLP_LAYER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.RESID_POST_ATTN_PROJ_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ATTN_LAYER_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_WRITE_TO_FINAL_RESIDUAL_GRAD,
            DerivedScalarType.SINGLE_NODE_WRITE_TO_FINAL_RESIDUAL_GRAD,
        }

        # Members whose names end in ATTRIBUTION.
        attribution_types = {
            DerivedScalarType.TOKEN_ATTRIBUTION,
            DerivedScalarType.GRAD_OF_SINGLE_SUBNODE_ATTRIBUTION,
            DerivedScalarType.ATTN_OUT_EDGE_ATTRIBUTION,
            DerivedScalarType.MLP_OUT_EDGE_ATTRIBUTION,
            DerivedScalarType.ONLINE_AUTOENCODER_OUT_EDGE_ATTRIBUTION,
            DerivedScalarType.ATTN_QUERY_IN_EDGE_ATTRIBUTION,
            DerivedScalarType.ATTN_KEY_IN_EDGE_ATTRIBUTION,
            DerivedScalarType.ATTN_VALUE_IN_EDGE_ATTRIBUTION,
            DerivedScalarType.MLP_IN_EDGE_ATTRIBUTION,
            DerivedScalarType.ONLINE_AUTOENCODER_IN_EDGE_ATTRIBUTION,
            DerivedScalarType.TOKEN_OUT_EDGE_ATTRIBUTION,
        }

        matching_types = act_times_grad_types | final_residual_grad_types | attribution_types
        return self in matching_types