def step()

in optimum/neuron/accelerate/optimizer.py [0:0]


    def step(self, closure=None):
        from neuronx_distributed import parallel_layers
        from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients

        if self.gradient_state.sync_gradients:
            # For sequence parallelism, the layer norm gradients have to be explicitly all-reduced.
            if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
                allreduce_sequence_parallel_gradients(self.optimizer)

            if isinstance(self.optimizer, ZeroRedundancyOptimizer):
                if self.clip_grad_norm_to_perform is not None:
                    # `ZeroRedundancyOptimizer` does not allow passing a norm type; it could be supported, but this is
                    # postponed for now.
                    self.optimizer.grad_clipping = True
                    self.optimizer.max_norm = self.clip_grad_norm_to_perform["max_norm"]
                else:
                    self.optimizer.grad_clipping = False
                self.optimizer.step(closure=closure)
                # Reset the clipping state so it does not carry over to the next step.
                self.optimizer.grad_clipping = False
                self.clip_grad_norm_to_perform = None
            elif (
                self.accelerator_state.distributed_type is DistributedType.XLA
                or self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
            ):
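                # All-reduce the gradients over the data-parallel group before stepping.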
                if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                    bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
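                # Apply the gradient clipping that was deferred to the optimizer step, using the
                # parallelism-aware `clip_grad_norm` from `neuronx_distributed`.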
                if self.clip_grad_norm_to_perform is not None:
                    parameters = self.clip_grad_norm_to_perform.pop("parameters", None)
                    if parameters is not None:
                        self.grad_norm = parallel_layers.clip_grad_norm(parameters, **self.clip_grad_norm_to_perform)
                    self.clip_grad_norm_to_perform = None
                self.optimizer.step(closure=closure)
            elif self.scaler is not None:
                scale_before = self.scaler.get_scale()
                self.scaler.step(self.optimizer, closure)
                self.scaler.update()
                scale_after = self.scaler.get_scale()
                # If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
                self._is_overflow = scale_after < scale_before
            else:
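                # No XLA-specific handling and no gradient scaler: simply step the wrapped optimizer.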
                self.optimizer.step(closure)
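
For context, here is a minimal sketch of a training loop that ends up calling this `step` method. It assumes the standard Hugging Face `accelerate` API with placeholder model, data, and hyper-parameters; on Trainium the accelerator would typically come from `optimum.neuron` and would wrap the optimizer in the class this method belongs to, but the call pattern is the same:

    import torch
    from accelerate import Accelerator

    # Toy model and data, placeholders only.
    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for inputs, labels in dataloader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        # Assumption: with the Neuron accelerator, this records the clipping request
        # (`clip_grad_norm_to_perform`) and the actual clipping happens inside `step()`.
        accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()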