in optimum/neuron/accelerate/optimizer.py
def step(self, closure=None):
    from neuronx_distributed import parallel_layers
    from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients

    if self.gradient_state.sync_gradients:
        # When sequence parallelism is used, the layer-norm gradients have to be all-reduced explicitly.
        if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
            allreduce_sequence_parallel_gradients(self.optimizer)
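            # Sequence parallelism shards activations along the sequence dimension across the
            # tensor-parallel group, so parameters acting on those sharded activations (e.g.
            # layer norms) only hold partial gradients on each rank; the call above sums them
            # across that group before stepping.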
        if isinstance(self.optimizer, ZeroRedundancyOptimizer):
            if self.clip_grad_norm_to_perform is not None:
                # `ZeroRedundancyOptimizer` does not allow passing a norm type; it could be
                # supported, but this is postponed for now.
                self.optimizer.grad_clipping = True
                self.optimizer.max_norm = self.clip_grad_norm_to_perform["max_norm"]
            else:
                self.optimizer.grad_clipping = False
            self.optimizer.step(closure=closure)
            # Reset the gradient clipping state for the next step.
            self.optimizer.grad_clipping = False
            self.clip_grad_norm_to_perform = None
        elif (
            self.accelerator_state.distributed_type is DistributedType.XLA
            or self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
        ):
            # Stepping the optimizer directly does not all-reduce gradients across
            # data-parallel ranks on XLA, so the reduction is performed explicitly
            # (in buckets) before stepping.
            if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
            if self.clip_grad_norm_to_perform is not None:
                parameters = self.clip_grad_norm_to_perform.pop("parameters", None)
                if parameters is not None:
                    self.grad_norm = parallel_layers.clip_grad_norm(parameters, **self.clip_grad_norm_to_perform)
                self.clip_grad_norm_to_perform = None
            self.optimizer.step(closure=closure)
        elif self.scaler is not None:
            scale_before = self.scaler.get_scale()
            self.scaler.step(self.optimizer, closure)
            self.scaler.update()
            scale_after = self.scaler.get_scale()
            # If the loss scale was reduced, the optimizer step was skipped because of gradient overflow.
            self._is_overflow = scale_after < scale_before
        else:
            self.optimizer.step(closure)
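For context, here is a minimal sketch (not part of the file above) of how this `step` is typically driven from an Accelerate-style training loop. It assumes `NeuronAccelerator` from `optimum.neuron.accelerate` and that the accelerator's `clip_grad_norm_` is what records `clip_grad_norm_to_perform` on the wrapped optimizer; the model, optimizer, and dataloader are illustrative placeholders.

import torch
from optimum.neuron.accelerate import NeuronAccelerator

# Illustrative placeholders: any model / optimizer / dataloader works once prepared.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))),
    batch_size=4,
)

accelerator = NeuronAccelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, labels in dataloader:
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    accelerator.backward(loss)
    # Assumed to record the clipping to perform on the wrapped optimizer
    # (`clip_grad_norm_to_perform`), which `step` above then applies once
    # gradients have been synchronized.
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()  # runs the method shown above
    optimizer.zero_grad()

Deferring the clipping into `step` means the norm is computed on fully synchronized gradients, which is why the XLA/model-parallel branch above only calls `parallel_layers.clip_grad_norm` after the data-parallel all-reduce.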