in step8_pipeline_parallel_1f1b/data_parallel.py [0:0]
def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
# Set of parameters in this bucket.
self.params = set(params)
# Parameters whose gradients are ready for synchronization. The all-reduce is launched once every parameter in the bucket has its gradient ready.
self.params_with_grad_ready = set()
# Flat tensor that stores the gradients of all parameters in this bucket.
self.grad_data = grad_data
# Process group for gradient synchronization.
self.process_group = process_group
self.process_group_size = dist.get_world_size(group=self.process_group)
# Handle for the async allreduce operation.
self.handle = None
self.reset()
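
A minimal sketch of how such a bucket could drive the asynchronous gradient all-reduce, assuming the fields initialized above; the class name Bucket and the methods reset, mark_param_as_ready, sync_gradient, and wait are illustrative names, not necessarily the repository's actual API.

import torch
import torch.distributed as dist
from typing import List, Optional

class Bucket:
    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor,
                 process_group: dist.ProcessGroup) -> None:
        self.params = set(params)
        self.params_with_grad_ready = set()
        self.grad_data = grad_data
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group)
        self.handle: Optional[dist.Work] = None
        self.reset()

    def reset(self) -> None:
        # Clear the ready set and zero the flat gradient buffer for the next iteration.
        self.handle = None
        self.params_with_grad_ready.clear()
        self.grad_data.zero_()

    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        # Record that this parameter's gradient has been accumulated; once every
        # parameter in the bucket is ready, launch the all-reduce.
        self.params_with_grad_ready.add(param)
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_gradient()

    def sync_gradient(self) -> None:
        # Average gradients across data-parallel ranks: pre-divide by the group size,
        # then all-reduce (sum) asynchronously so compute can overlap with communication.
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)

    def wait(self) -> None:
        # Block until the in-flight all-reduce, if any, has completed.
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

In a typical setup, mark_param_as_ready would be called from a post-accumulate gradient hook on each parameter, and wait followed by reset would run before the optimizer step.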