def apply_gradients()

in training/flax/run_distillation.py [0:0]


    def apply_gradients(self, *, grads, to_dtype=to_fp32, **kwargs):
        """Clip `grads` by global L2 norm, apply them, and return the updated state.

        The gradients are rescaled by `max_grad_norm / max(max_grad_norm, ||grads||)`,
        so they are left unchanged when their global norm is at most
        `self.max_grad_norm` and scaled down to that norm otherwise.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
          grads: Gradients that have the same pytree structure as `.params`.
          to_dtype: Callable used to cast `self.max_grad_norm` before clipping and
            to cast the new optimizer state before it is stored. Defaults to
            `to_fp32`.
          **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
          An updated instance of `self` with `step` incremented by one, `params`
          and `opt_state` updated by applying `grads`, and additional attributes
          replaced as specified by `kwargs`.
        """
        # clip gradients by global l2 norm
        casted_max_grad_norm = to_dtype(self.max_grad_norm)
        g_norm = linear_algebra.global_norm(grads)
        # Flooring the divisor at `max_grad_norm` keeps the scale factor <= 1, so
        # gradients are only ever shrunk, never amplified.
        g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
        # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`
        # alias, which newer JAX releases no longer provide.
        grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)

        # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
        # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
        updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)

        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=to_dtype(new_opt_state),
            **kwargs,
        )