in src/accelerate/accelerator.py
def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
"""
Should be used in place of `torch.nn.utils.clip_grad_norm_`.
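Args:
parameters (`Iterable[torch.nn.Parameter]`):
The parameters whose gradients will be clipped, e.g. `model.parameters()`.
max_norm (`float`):
The maximum norm of the gradients.
norm_type (`float` or `int`, *optional*, defaults to 2):
The type of the used p-norm.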
Returns:
`torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).
Example:
```python
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(gradient_accumulation_steps=2)
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for input, target in dataloader:
... optimizer.zero_grad()
... output = model(input)
... loss = loss_func(output, target)
... accelerator.backward(loss)
... if accelerator.sync_gradients:
... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
... optimizer.step()
```
"""
if self.distributed_type == DistributedType.FSDP:
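# FSDP shards parameters across ranks, so clipping goes through the wrapped model: FSDP1
# exposes its own `clip_grad_norm_`, while FSDP2 parameters are DTensors that
# `torch.nn.utils.clip_grad_norm_` handles directly (see the link below).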
self.unscale_gradients()
parameters = list(parameters)  # materialize the iterator so it can be compared against each model's parameters
for model in self._models:
if parameters == list(model.parameters()):
if not self.is_fsdp2:
return model.clip_grad_norm_(max_norm, norm_type)
else:
return torch.nn.utils.clip_grad_norm_(
parameters, max_norm, norm_type=norm_type
) # viz: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
elif self.distributed_type == DistributedType.DEEPSPEED:
# DeepSpeed handles gradient clipping internally, but we can retrieve the gradient norm
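# (the clipping threshold itself is taken from the `gradient_clipping` value in the DeepSpeed config)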
if self.deepspeed_engine_wrapped is not None:
return self.deepspeed_engine_wrapped.get_global_grad_norm()
return None
elif self.distributed_type == DistributedType.XLA:
# Reduce gradients first for XLA
for acc_opt in self._optimizers:
if not acc_opt.gradient_state.is_xla_gradients_synced:
opt = acc_opt
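# Unwrap any nested AcceleratedOptimizer layers to reach the underlying optimizer before fetching its gradients.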
while isinstance(opt, AcceleratedOptimizer):
opt = opt.optimizer
gradients = xm._fetch_gradients(opt)
# Use xm.all_reduce to perform an in-place all-reduce. Recursively all-reducing each tensor
# one by one in self.reduce is not in-place.
xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
# Mark is_xla_gradients_synced as True so the AcceleratedOptimizer step does not all-reduce the gradients a second time.
acc_opt.gradient_state.is_xla_gradients_synced = True
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
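# FSDP on XLA: mirror the FSDP branch above and clip through the wrapped model.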
self.unscale_gradients()
parameters = list(parameters)
for model in self._models:
if parameters == list(model.parameters()):
return model.clip_grad_norm_(max_norm, norm_type)
self.unscale_gradients()
return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
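
As a usage note, here is a minimal sketch (not part of the source) of consuming the returned value inside the training loop from the docstring example; `accelerator`, `model`, and `max_grad_norm` are assumed to be defined as there, and `torch` to be imported. Depending on the branch taken above, the result may be a tensor, a plain float (DeepSpeed), or `None`, so it is normalized before printing.

```python
if accelerator.sync_gradients:
    grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
    # The FSDP/XLA/default paths return a tensor; the DeepSpeed path returns the
    # engine's global grad norm, which may not be a tensor and may still be None.
    if grad_norm is not None:
        grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm)
        print(f"grad norm: {grad_norm:.4f}")
```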