in picotron/data_parallel/data_parallel.py [0:0]
def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
    """
    Creates a hook for each parameter to handle gradient accumulation and synchronization.
    """
    def param_hook(*unused):
        """
        Hook called after the parameter's gradient is ready. It performs the following:
        1. Accumulates the gradient into the main gradient.
        2. Adds a post-backward callback to wait for gradient synchronization completion.
        3. Marks the parameter as ready for synchronization.
        """
        if param.requires_grad:
            assert param.grad is not None
            param.main_grad.add_(param.grad.data)  # accumulate the gradient into the main gradient
            param.grad = None
            # Skip gradient synchronization while still accumulating gradients (gradient accumulation / PP micro-batches).
            if self.require_backward_grad_sync:
                # Add a callback that waits for gradient synchronization after the backward pass.
                # It must be registered only once per backward pass.
                if not self._post_backward_callback_set:
                    Variable._execution_engine.queue_callback(self._post_backward)
                    self._post_backward_callback_set = True
                # Mark the parameter as ready for gradient synchronization.
                bucket_manager.mark_param_as_ready(param)
    return param_hook
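
For context, a hook like this is typically registered on each parameter's gradient-accumulation node (the `AccumulateGrad` function reachable via the `grad_fn` of `param.expand_as(param)`), not on the parameter tensor itself, so it fires after `.grad` has been populated. Below is a minimal sketch of such a registration loop; `self.module`, `self.bucket_manager`, and `self.grad_accs` are assumed attribute names for illustration and may not match the actual picotron code.

```python
def _register_backward_hooks(self):
    """Sketch: attach param_hook to every trainable parameter's gradient accumulator.

    Assumes `self.module` (the wrapped nn.Module), `self.bucket_manager`, and a
    `self.grad_accs` list used to keep the accumulator nodes alive.
    """
    self.grad_accs = []
    for param in self.module.parameters():
        if param.requires_grad:
            # `param.expand_as(param)` creates an autograd node whose next function
            # is the AccumulateGrad node of `param`; a hook on that node runs once
            # the gradient for this parameter has been accumulated.
            grad_acc_fn = param.expand_as(param).grad_fn.next_functions[0][0]
            grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
            # Keep a reference so the accumulator node is not garbage collected.
            self.grad_accs.append(grad_acc_fn)
```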