# neuron_explainer/activations/derived_scalars/edge_attribution.py
"""
This file contains functions for computing the importance of edges in a transformer computation
graph.
Edges are taken to go from an upstream node (defined to be an MLP neuron, autoencoder latent, or
attention head at a specific token or token pair) to a downstream "subnode" (defined to be an MLP
neuron, autoencoder latent, or attention head Q, K, or V at a specific token or token pair).
Notice that we consider Q, K, V subnodes for attention separately for the downstream partner of the
edge, but lump them together for the upstream partner.
Note that the inputs to an attention head node are specified as either Q, K, or V-mediated, while
the outputs are specified merely as originating from the attention head node.
'Importance' of an edge is defined using act * grad, or 'attribution'.
See eq. 2 of https://arxiv.org/pdf/2310.10348.pdf for more context, but here briefly:
The attribution of an edge can be computed as:
dLoss/d"EdgeActivation" * "EdgeActivation"
= dLoss/dDownstreamSubNodeActivation * ∂DownstreamSubNodeActivation/∂UpstreamNodeActivation * UpstreamNodeActivation
= dLoss/dDownstreamSubNodeActivation * dDownstreamSubNodeActivation/dResidualStream * UpstreamNodeWriteToResidualStream
where the Write terms indicate (d_model,) vectors per token or per token pair, dX/dY indicates the
total derivative of X with respect to Y, and ∂X/∂Y indicates the partial derivative. "Total
derivatives" are also known as "gradients", while "partial derivatives" are also known as "direct
writes" to gradient directions (i.e. dLoss/dDownstreamSubNodeActivation is the gradient at the
downstream node, while ∂DownstreamSubNodeActivation/∂UpstreamNodeActivation is the "direct write"
from the upstream to the downstream node). "EdgeActivation" is considered to be the direct
effect of patching the upstream node's activation into the downstream subnode's input. The
"Activation" of a node or subnode is considered to be any sufficient statistic for determining that
node's effect on downstream model components (e.g. the activation of a single MLP neuron, or all
d_head channels of an attention query at a particular pair of tokens). "ResidualStream" refers to
the residual stream just before the downstream subnode in question (edges correspond to direct
writes between nodes).
The strategy within this file is:
1. construct a function to compute
[dLoss/dDownstreamSubNodeActivation * dDownstreamSubNodeActivation(ResidualStream)](ResidualStream)
:= DownstreamSubNodeAttribution(ResidualStream)
with a stopgrad on the dLoss/dDownstreamSubNodeActivation term, for ONE OR MORE downstream subnodes.
This flexibility (one or more downstream subnodes) supports reuse in two contexts: many upstream
nodes to one downstream node, or one upstream node to many downstream nodes.
(make_reconstituted_attribution_fn, which is used to construct AttributionReconstituter)
### MANY-UPSTREAM-TO-ONE-DOWNSTREAM CASE ###
2. construct a function to compute dDownstreamSubNodeAttribution/dResidualStream for JUST ONE
downstream subnode (using AttributionReconstituter)
3. construct a ScalarDeriver by taking the inner product of MANY upstream nodes' write vectors with
the gradient of the attribution of ONE downstream subnode
(convert_scalar_deriver_to_out_edge_attributions)
### ONE-UPSTREAM-TO-MANY-DOWNSTREAM CASE ###
4. construct a function to compute dDownstreamSubNodeAttribution/dResidualStream * WriteVector for
MANY downstream subnodes and ONE upstream node (using AttributionReconstituter)
5. construct a ScalarDeriver for the write vector of a single upstream node (note that this write
vector is per token, even if the upstream node is per token pair; in this case it will be the
contribution of just one (sequence token, attended to token) pair to the sequence token)
(make_node_write_scalar_deriver)
6. convert this ScalarDeriver to a ScalarDeriver for the in edge attribution of many downstream
subnodes, originating from one upstream node
(convert_node_write_scalar_deriver_to_in_edge_attribution,
make_in_edge_attribution_scalar_deriver_factory)
An AttributionReconstituter object is used to reconstruct the attribution of the downstream node(s).
The attribution of the edge is computed by taking derivatives of the downstream node attribution
with respect to the residual stream (where derivatives are either gradients, in the case where there
is one downstream node, or Jacobians, in the case where there are many downstream nodes).
"""
import dataclasses
from typing import Callable
import torch
from neuron_explainer.activations.derived_scalars.attention import (
make_attn_weighted_value_scalar_deriver,
)
from neuron_explainer.activations.derived_scalars.autoencoder import (
make_online_autoencoder_latent_scalar_deriver_factory,
)
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.direct_effects import (
convert_scalar_deriver_to_write_to_direction,
convert_scalar_deriver_to_write_to_final_residual_grad,
)
from neuron_explainer.activations.derived_scalars.indexing import (
AttentionTraceType,
AttnSubNodeIndex,
NodeIndex,
PreOrPostAct,
)
from neuron_explainer.activations.derived_scalars.locations import (
ConstantLayerIndexer,
IdentityLayerIndexer,
get_previous_residual_dst_for_node_type,
)
from neuron_explainer.activations.derived_scalars.mlp import get_base_mlp_scalar_deriver
from neuron_explainer.activations.derived_scalars.node_write import (
make_node_write_scalar_deriver,
make_node_write_scalar_source,
)
from neuron_explainer.activations.derived_scalars.raw_activations import (
make_scalar_deriver_factory_for_activation_location_type,
)
from neuron_explainer.activations.derived_scalars.reconstituted import (
make_apply_attn_V_act,
make_reconstituted_activation_fn,
)
from neuron_explainer.activations.derived_scalars.reconstituter_class import Reconstituter
from neuron_explainer.activations.derived_scalars.scalar_deriver import (
DerivedScalarSource,
DstConfig,
RawScalarSource,
ScalarDeriver,
ScalarSource,
)
from neuron_explainer.models.autoencoder_context import AutoencoderContext
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
LayerIndex,
NodeType,
PassType,
)
from neuron_explainer.models.model_context import StandardModelContext
from neuron_explainer.models.transformer import Transformer
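# The following is an illustrative, torch-only sketch of the attribution decomposition described
# in the module docstring; it is not part of the library API. All tensors and shapes here are
# hypothetical stand-ins: the real computation goes through the ScalarDeriver machinery below.
def _toy_edge_attribution_sketch() -> torch.Tensor:
    d_model = 4
    resid = torch.randn(d_model, requires_grad=True)  # residual stream before the downstream subnode
    w_in = torch.randn(d_model)  # hypothetical input weights of the downstream "neuron"
    downstream_grad = torch.tensor(2.0)  # stand-in for dLoss/dDownstreamSubNodeActivation
    activation = torch.relu(resid @ w_in)  # DownstreamSubNodeActivation(ResidualStream)
    attribution = downstream_grad.detach() * activation  # act * grad, with a stopgrad on grad
    # dDownstreamSubNodeAttribution/dResidualStream, as in the many-upstream-to-one-downstream case
    (grad_of_attribution,) = torch.autograd.grad(attribution, resid)
    upstream_write = torch.randn(d_model)  # UpstreamNodeWriteToResidualStream (hypothetical)
    return grad_of_attribution @ upstream_write  # attribution of this single edge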
def get_activation_location_type_for_node_type(
node_type: NodeType, q_k_or_v: ActivationLocationType | None
) -> ActivationLocationType:
"""This returns the activation location associated with a node of a given type, and
specifying Q, K, or V in the case of attention. This returns an activation location type
that is sufficient to determine that node's effect on the residual stream (post-softmax
in the case of Q, K; ATTN_WEIGHTED_SUM_OF_VALUES in the case of V)"""
match node_type:
case NodeType.ATTENTION_HEAD:
assert q_k_or_v is not None
match q_k_or_v:
case ActivationLocationType.ATTN_VALUE:
return ActivationLocationType.ATTN_WEIGHTED_SUM_OF_VALUES
case ActivationLocationType.ATTN_QUERY | ActivationLocationType.ATTN_KEY:
return ActivationLocationType.ATTN_QK_PROBS
case _:
raise NotImplementedError(q_k_or_v)
case NodeType.MLP_NEURON:
assert q_k_or_v is None
return ActivationLocationType.MLP_POST_ACT
case NodeType.AUTOENCODER_LATENT:
assert q_k_or_v is None
return ActivationLocationType.ONLINE_AUTOENCODER_LATENT
case NodeType.MLP_AUTOENCODER_LATENT:
assert q_k_or_v is None
return ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT
case NodeType.ATTENTION_AUTOENCODER_LATENT:
assert q_k_or_v is None
return ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT
case _:
raise NotImplementedError(node_type)
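# A hedged usage sketch of the mapping above (the assertions restate the match statement):
def _activation_location_mapping_sketch() -> None:
    # Q- and K-mediated attention subnodes resolve to the post-softmax attention probabilities.
    assert (
        get_activation_location_type_for_node_type(
            NodeType.ATTENTION_HEAD, ActivationLocationType.ATTN_QUERY
        )
        == ActivationLocationType.ATTN_QK_PROBS
    )
    # V-mediated attention subnodes resolve to the attention-weighted sum of values.
    assert (
        get_activation_location_type_for_node_type(
            NodeType.ATTENTION_HEAD, ActivationLocationType.ATTN_VALUE
        )
        == ActivationLocationType.ATTN_WEIGHTED_SUM_OF_VALUES
    )
    # Non-attention nodes take no Q/K/V specifier.
    assert (
        get_activation_location_type_for_node_type(NodeType.MLP_NEURON, None)
        == ActivationLocationType.MLP_POST_ACT
    )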
def make_reconstituted_attribution_fn(
transformer: Transformer,
autoencoder_context: AutoencoderContext | None,
node_type: NodeType, # the type of the node of interest
q_k_or_v: (
ActivationLocationType | None
), # if node_type is ATTENTION_HEAD, this specifies Q, K, or V
detach_layer_norm_scale: bool,
) -> Callable[[torch.Tensor, torch.Tensor, LayerIndex, PassType], torch.Tensor]:
"""The 'attribution' of a node is taken to be the product of the node's gradient and the node's activation.
This returns a function to compute the attribution of attention heads (specifically mediated by Q, K, or V), MLP activations,
or autoencoder activations. The input expected by that function is the residual stream just before the node in question.
The function returned can be used for further analysis, e.g. computing the gradient of the attribution with respect to the
input residual stream.
Note that this can be used to compute the attribution of one OR many downstream subnodes, depending on the tensor_indices_for_grad
    (an empty tensor_indices_for_grad corresponds to an entire layer's worth of activations)."""
    # for these functions, q_k_or_v must be specified if and only if node_type is ATTENTION_HEAD
    assert (q_k_or_v is None) == (node_type != NodeType.ATTENTION_HEAD)
match node_type:
case NodeType.ATTENTION_HEAD:
match q_k_or_v:
case ActivationLocationType.ATTN_QUERY | ActivationLocationType.ATTN_KEY:
# in all cases but attn value, the attribution fn is the hadamard product of the activation and the gradient
# NOTE: q_k_or_v = None covers all non-attention activations
match q_k_or_v:
case ActivationLocationType.ATTN_QUERY:
attention_trace_type = AttentionTraceType.Q
case ActivationLocationType.ATTN_KEY:
attention_trace_type = AttentionTraceType.K
case None:
attention_trace_type = AttentionTraceType.QK
case _:
raise NotImplementedError(q_k_or_v)
activation_fn = make_reconstituted_activation_fn(
transformer=transformer,
autoencoder_context=autoencoder_context,
node_type=node_type,
pre_or_post_act=PreOrPostAct.POST,
detach_layer_norm_scale=detach_layer_norm_scale,
attention_trace_type=attention_trace_type,
)
def attribution_fn(
resid: torch.Tensor,
grad: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
activation = activation_fn(resid, layer_index, pass_type)
assert activation.shape == grad.shape, (
activation.shape,
grad.shape,
)
return activation * grad
case ActivationLocationType.ATTN_VALUE:
assert (
get_activation_location_type_for_node_type(node_type, q_k_or_v)
== ActivationLocationType.ATTN_WEIGHTED_SUM_OF_VALUES
)
apply_attn_V_act = make_apply_attn_V_act(
transformer=transformer,
q_k_or_v=q_k_or_v,
detach_layer_norm_scale=detach_layer_norm_scale,
)
def attribution_fn(
resid: torch.Tensor,
grad: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
attn, V = apply_attn_V_act(resid, layer_index, pass_type)
attn_weighted_V = torch.einsum("qkh,khd->qkhd", attn, V)
# grad is w/r/t (attn_weighted_V summed over k, or ATTN_WEIGHTED_SUM_OF_VALUES)
return torch.einsum("qkhd,qhd->qkh", attn_weighted_V, grad)
case _:
raise NotImplementedError(q_k_or_v)
case _:
activation_fn = make_reconstituted_activation_fn(
transformer=transformer,
autoencoder_context=autoencoder_context,
node_type=node_type,
pre_or_post_act=PreOrPostAct.POST,
detach_layer_norm_scale=detach_layer_norm_scale,
attention_trace_type=None,
)
def attribution_fn(
resid: torch.Tensor,
grad: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
activation = activation_fn(resid, layer_index, pass_type)
assert activation.shape == grad.shape, (
activation.shape,
grad.shape,
)
return activation * grad
return attribution_fn
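# Shape sketch for the V-mediated branch above, using assumed dimension names
# (n_query, n_key, n_heads, d_head); this is illustrative only and mirrors the two einsums.
def _v_mediated_attribution_shape_sketch() -> torch.Size:
    n_query, n_key, n_heads, d_head = 3, 3, 2, 5
    attn = torch.rand(n_query, n_key, n_heads)  # post-softmax attention probabilities
    V = torch.randn(n_key, n_heads, d_head)  # value vectors
    grad = torch.randn(n_query, n_heads, d_head)  # grad w/r/t ATTN_WEIGHTED_SUM_OF_VALUES
    attn_weighted_V = torch.einsum("qkh,khd->qkhd", attn, V)
    edge_attribution = torch.einsum("qkhd,qhd->qkh", attn_weighted_V, grad)
    return edge_attribution.shape  # one attribution value per (query token, key token, head)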
class AttributionReconstituter(Reconstituter):
"""Reconstitute MLP, autoencoder, or attention node attribution (act * grad at node).
Attention nodes are required to be split into Q, K, or V subnodes."""
requires_other_scalar_source = True
def __init__(
self,
transformer: Transformer,
autoencoder_context: AutoencoderContext | None,
node_type: NodeType,
q_k_or_v: ActivationLocationType | None,
detach_layer_norm_scale: bool,
):
super().__init__()
self._reconstitute_activations_fn = make_reconstituted_attribution_fn(
transformer=transformer,
autoencoder_context=autoencoder_context,
node_type=node_type,
q_k_or_v=q_k_or_v,
detach_layer_norm_scale=detach_layer_norm_scale,
)
self.node_type = node_type
self.q_k_or_v = q_k_or_v
self.residual_dst = get_previous_residual_dst_for_node_type(
node_type=node_type,
autoencoder_dst=autoencoder_context.dst if autoencoder_context is not None else None,
)
def reconstitute_activations(
self,
resid: torch.Tensor,
grad: torch.Tensor | None,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert grad is not None
return self._reconstitute_activations_fn(
resid,
grad,
layer_index,
pass_type,
)
def make_other_scalar_source(self, _unused_dst_config: DstConfig) -> ScalarSource:
activation_location_type = get_activation_location_type_for_node_type(
node_type=self.node_type,
q_k_or_v=self.q_k_or_v,
)
return RawScalarSource(
activation_location_type=activation_location_type,
pass_type=PassType.BACKWARD,
layer_indexer=IdentityLayerIndexer(),
) # this provides the 'grad' argument required by reconstitute_activations
def _check_node_index(self, node_index: NodeIndex) -> None:
assert node_index.node_type == self.node_type
assert node_index.pass_type == PassType.FORWARD
assert node_index.layer_index is not None
if node_index.node_type == NodeType.ATTENTION_HEAD:
assert isinstance(node_index, AttnSubNodeIndex)
assert node_index.q_k_or_v == self.q_k_or_v
def make_scalar_hook_for_node_index(
self, node_index: NodeIndex
) -> Callable[[torch.Tensor], torch.Tensor]:
self._check_node_index(node_index)
def get_activation_from_layer_activations(layer_activations: torch.Tensor) -> torch.Tensor:
return layer_activations[node_index.tensor_indices]
return get_activation_from_layer_activations
def make_gradient_scalar_deriver_for_node_index(
self,
node_index: NodeIndex,
dst_config: DstConfig,
output_dst: DerivedScalarType | None = None,
) -> ScalarDeriver:
self._check_node_index(node_index)
assert node_index.layer_index is not None
dst_config_for_layer = dataclasses.replace(
dst_config,
layer_indices=[node_index.layer_index],
)
scalar_hook = self.make_scalar_hook_for_node_index(node_index)
return self.make_gradient_scalar_deriver(
scalar_hook=scalar_hook,
dst_config=dst_config_for_layer,
output_dst=output_dst,
)
def make_gradient_scalar_source_for_node_index(
self,
node_index: NodeIndex,
dst_config: DstConfig,
output_dst: DerivedScalarType | None = None,
) -> DerivedScalarSource:
scalar_hook = self.make_scalar_hook_for_node_index(node_index)
gradient_scalar_deriver = self.make_gradient_scalar_deriver(
scalar_hook=scalar_hook,
dst_config=dst_config,
output_dst=output_dst,
)
assert node_index.layer_index is not None
return DerivedScalarSource(
scalar_deriver=gradient_scalar_deriver,
pass_type=PassType.FORWARD,
layer_indexer=ConstantLayerIndexer(node_index.layer_index),
)
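# Minimal sketch (hypothetical shapes) of what the scalar hook produced by
# make_scalar_hook_for_node_index does: it indexes the layer-wide activation tensor down to the
# single node whose attribution gradient is being taken.
def _scalar_hook_indexing_sketch() -> torch.Tensor:
    layer_activations = torch.randn(4, 16)  # e.g. (n_tokens, n_latents_in_layer)
    tensor_indices = (2, 7)  # (token index, latent index) of the node of interest
    return layer_activations[tensor_indices]  # 0-d tensor: that node's activation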
def _make_attribution_reconstituter_for_one_downstream_node(
dst_config: DstConfig,
) -> AttributionReconstituter:
# in the case of computing attribution of edges from many upstream to one downstream node, the dst_config
# contains the information necessary to construct the Reconstituter. This is because the activation being
# reconstituted corresponds to dst_config.node_index_for_attribution
node_index_for_attribution = dst_config.node_index_for_attribution
assert node_index_for_attribution is not None
node_type = node_index_for_attribution.node_type
if isinstance(node_index_for_attribution, AttnSubNodeIndex):
q_k_or_v = node_index_for_attribution.q_k_or_v
else:
q_k_or_v = None
assert (node_type == NodeType.ATTENTION_HEAD) == (q_k_or_v is not None)
model_context = dst_config.get_model_context()
transformer = model_context.get_or_create_model()
autoencoder_context = dst_config.get_autoencoder_context()
return AttributionReconstituter(
transformer=transformer,
autoencoder_context=autoencoder_context,
node_type=node_type,
q_k_or_v=q_k_or_v,
detach_layer_norm_scale=dst_config.detach_layer_norm_scale_for_attribution,
)
### MANY-UPSTREAM-TO-ONE-DOWNSTREAM CASE ###
def make_grad_of_downstream_subnode_attribution_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
"""Computes the gradient with respect to the preceding residual stream of the downstream subnode's attribution
(d(Activation(ResidualStream) * Gradient)/dResidualStream), with a stopgrad on the "Gradient" term.
"""
node_index = dst_config.node_index_for_attribution
assert node_index is not None
reconstituter = _make_attribution_reconstituter_for_one_downstream_node(dst_config)
return reconstituter.make_gradient_scalar_deriver_for_node_index(
node_index=node_index,
dst_config=dst_config,
output_dst=DerivedScalarType.GRAD_OF_SINGLE_SUBNODE_ATTRIBUTION,
)
def convert_scalar_deriver_to_out_edge_attributions(
scalar_deriver: ScalarDeriver,
output_dst: DerivedScalarType,
) -> ScalarDeriver:
"""Converts a scalar deriver for an activation of some kind to a scalar deriver for the
attribution of edges going out from that activation to the node specified by
trace_config (which can be autoencoder, MLP, or attention head-- and in the case of
attention head, specifically the edge going to Q, K, or V)."""
reconstituter = _make_attribution_reconstituter_for_one_downstream_node(
scalar_deriver.dst_config,
)
node_index = scalar_deriver.dst_config.node_index_for_attribution
assert node_index is not None
attribution_grad_scalar_source = reconstituter.make_gradient_scalar_source_for_node_index(
node_index=node_index,
dst_config=scalar_deriver.dst_config,
output_dst=DerivedScalarType.GRAD_OF_SINGLE_SUBNODE_ATTRIBUTION,
)
return convert_scalar_deriver_to_write_to_direction(
scalar_deriver=scalar_deriver,
direction_scalar_source=attribution_grad_scalar_source,
output_dst=output_dst,
)
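# In other words, in the many-upstream-to-one-downstream case, each upstream node's out edge
# attribution is its write vector dotted with the gradient of the downstream subnode's
# attribution with respect to the residual stream; convert_scalar_deriver_to_write_to_direction
# performs that contraction per token (or per token pair).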
def make_attn_out_edge_attribution_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
"""Returns a scalar deriver for the attention value weighted by the post-softmax
attention between each pair of tokens."""
attn_weighted_value_scalar_deriver = make_attn_weighted_value_scalar_deriver(dst_config)
return convert_scalar_deriver_to_out_edge_attributions(
scalar_deriver=attn_weighted_value_scalar_deriver,
output_dst=DerivedScalarType.ATTN_OUT_EDGE_ATTRIBUTION,
)
def make_mlp_out_edge_attribution_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
"""Returns a scalar deriver for the edge attribution of the MLP output layer at each token."""
scalar_deriver = get_base_mlp_scalar_deriver(
dst_config=dst_config,
)
return convert_scalar_deriver_to_out_edge_attributions(
scalar_deriver=scalar_deriver,
output_dst=DerivedScalarType.MLP_OUT_EDGE_ATTRIBUTION,
)
def make_online_autoencoder_out_edge_attribution_scalar_deriver(
dst_config: DstConfig,
node_type: NodeType | None = None,
) -> ScalarDeriver:
"""Returns a scalar deriver for the edge attribution of the MLP output layer at each token."""
scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(dst_config)
return convert_scalar_deriver_to_out_edge_attributions(
scalar_deriver=scalar_deriver,
output_dst=DerivedScalarType.ONLINE_AUTOENCODER_OUT_EDGE_ATTRIBUTION,
)
def make_token_out_edge_attribution_scalar_deriver(dst_config: DstConfig) -> ScalarDeriver:
"""This computes an attribution value for the edge from each token in the sequence to a particular
downstream node."""
node_index = dst_config.node_index_for_attribution
assert node_index is not None
reconstituter = _make_attribution_reconstituter_for_one_downstream_node(dst_config)
emb_scalar_deriver = make_scalar_deriver_factory_for_activation_location_type(
activation_location_type=ActivationLocationType.RESID_POST_EMBEDDING,
)(dst_config)
attribution_grad_scalar_source = reconstituter.make_gradient_scalar_source_for_node_index(
node_index=node_index,
dst_config=emb_scalar_deriver.dst_config,
output_dst=DerivedScalarType.GRAD_OF_SINGLE_SUBNODE_ATTRIBUTION,
)
return convert_scalar_deriver_to_write_to_direction(
scalar_deriver=emb_scalar_deriver,
direction_scalar_source=attribution_grad_scalar_source,
output_dst=DerivedScalarType.TOKEN_OUT_EDGE_ATTRIBUTION,
)
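# Note that for token-level edges the upstream "write vector" is simply the token's
# post-embedding residual vector, so the edge attribution is that vector dotted with the gradient
# of the downstream subnode's attribution.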
### ONE-UPSTREAM-TO-MANY-DOWNSTREAM CASE ###
def convert_node_write_scalar_deriver_to_in_edge_attribution(
node_write_scalar_source: ScalarSource,
output_dst: DerivedScalarType,
dst_config: DstConfig,
downstream_node_type: NodeType,
downstream_q_k_or_v: ActivationLocationType | None,
) -> ScalarDeriver:
"""Converts a scalar deriver for a write vector from some upstream node type to a scalar deriver for
in edge attribution for downstream nodes of some type (MLP, autoencoder, or attention head)."""
model_context = dst_config.get_model_context()
autoencoder_context = dst_config.get_autoencoder_context()
assert isinstance(model_context, StandardModelContext)
transformer = model_context.get_or_create_model()
reconstituter = AttributionReconstituter(
transformer=transformer,
autoencoder_context=autoencoder_context,
node_type=downstream_node_type,
q_k_or_v=downstream_q_k_or_v,
detach_layer_norm_scale=dst_config.detach_layer_norm_scale_for_attribution,
)
return reconstituter.make_jvp_scalar_deriver(
write_scalar_source=node_write_scalar_source,
dst_config=dst_config,
output_dst=output_dst,
)
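# Hedged sketch of the Jacobian-vector product underlying the one-upstream-to-many-downstream
# case, using only torch on toy tensors (all names and shapes here are hypothetical). The real
# computation is built by reconstituter.make_jvp_scalar_deriver above.
def _in_edge_attribution_jvp_sketch() -> torch.Tensor:
    d_model, n_downstream = 4, 3
    W = torch.randn(n_downstream, d_model)  # toy input weights, one row per downstream subnode
    downstream_grads = torch.randn(n_downstream)  # stand-ins for dLoss/dDownstreamSubNodeActivation
    def many_subnode_attribution(resid: torch.Tensor) -> torch.Tensor:
        activations = torch.relu(W @ resid)  # one activation per downstream subnode
        return activations * downstream_grads.detach()  # act * grad, with a stopgrad on grad
    resid = torch.randn(d_model)  # residual stream before the downstream layer
    upstream_write = torch.randn(d_model)  # write vector of the single upstream node
    _, in_edge_attributions = torch.autograd.functional.jvp(
        many_subnode_attribution, (resid,), (upstream_write,)
    )
    return in_edge_attributions  # one attribution value per downstream subnode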
def make_in_edge_attribution_scalar_deriver_factory(
node_type_for_attribution: NodeType,
q_k_or_v_for_attribution: ActivationLocationType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
"""Returns a function that creates a scalar deriver for the edge attribution from arbitrary node
to the specified downstream node type / sub node type (MLP, autoencoder, or attention head Q, K, or V).
"""
sub_node_type_to_output_dst = {
(NodeType.MLP_NEURON, None): DerivedScalarType.MLP_IN_EDGE_ATTRIBUTION,
(
NodeType.AUTOENCODER_LATENT,
None,
): DerivedScalarType.ONLINE_AUTOENCODER_IN_EDGE_ATTRIBUTION,
(
NodeType.ATTENTION_HEAD,
ActivationLocationType.ATTN_QUERY,
): DerivedScalarType.ATTN_QUERY_IN_EDGE_ATTRIBUTION,
(
NodeType.ATTENTION_HEAD,
ActivationLocationType.ATTN_KEY,
): DerivedScalarType.ATTN_KEY_IN_EDGE_ATTRIBUTION,
(
NodeType.ATTENTION_HEAD,
ActivationLocationType.ATTN_VALUE,
): DerivedScalarType.ATTN_VALUE_IN_EDGE_ATTRIBUTION,
}
output_dst = sub_node_type_to_output_dst[(node_type_for_attribution, q_k_or_v_for_attribution)]
def make_in_edge_attribution_scalar_deriver(dst_config: DstConfig) -> ScalarDeriver:
node_write_scalar_source = make_node_write_scalar_source(dst_config)
return convert_node_write_scalar_deriver_to_in_edge_attribution(
node_write_scalar_source=node_write_scalar_source,
output_dst=output_dst,
dst_config=dst_config,
downstream_node_type=node_type_for_attribution,
downstream_q_k_or_v=q_k_or_v_for_attribution,
)
return make_in_edge_attribution_scalar_deriver
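# For example (illustrative only): make_in_edge_attribution_scalar_deriver_factory(
#     NodeType.MLP_NEURON
# )(dst_config) yields a ScalarDeriver whose output dst is
# DerivedScalarType.MLP_IN_EDGE_ATTRIBUTION, with the single upstream node specified via
# dst_config (see make_node_write_scalar_source).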
def make_node_write_to_final_residual_grad_scalar_deriver(dst_config: DstConfig) -> ScalarDeriver:
"""Returns a scalar deriver for the write vector from some upstream node type
(MLP, autoencoder, or attention head) to the final residual grad. This can be used to compute
the edge attribution of the edge from that node to the loss itself."""
node_write_scalar_deriver = make_node_write_scalar_deriver(
dst_config
) # TODO: figure out how to thread
# the correct layer through to the final residual grad scalar source
return convert_scalar_deriver_to_write_to_final_residual_grad(
node_write_scalar_deriver,
output_dst=DerivedScalarType.SINGLE_NODE_WRITE_TO_FINAL_RESIDUAL_GRAD,
use_existing_backward_pass_for_final_residual_grad=True,
)