in picotron/data_parallel/bucket.py [0:0]
from typing import List

import torch
import torch.distributed as dist


class Bucket:
    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
        """
        Initializes a Bucket instance.

        Args:
            params (List[torch.nn.Parameter]): List of parameters assigned to this bucket.
            grad_data (torch.Tensor): Tensor that stores the gradients for this bucket.
            process_group (torch.distributed.ProcessGroup): Process group used for synchronizing gradients.
        """
        self.params = set(params)  # Set of parameters in this bucket.
        self.params_with_grad_ready = set()  # Parameters whose gradients are ready; the all-reduce is launched once every parameter in the bucket is ready.
        self.grad_data = grad_data  # Tensor that stores gradients for all parameters in this bucket.
        self.process_group = process_group  # Process group for gradient synchronization.
        self.process_group_size = dist.get_world_size(group=self.process_group)
        self.handle = None  # Handle for the asynchronous all-reduce operation.
        self.reset()  # Start the bucket in a clean state.
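
The comment on `params_with_grad_ready` implies the bucket's synchronization protocol: each parameter marks itself ready as its gradient lands in `grad_data`, and the asynchronous all-reduce is launched only once every parameter in the bucket is ready. Below is a minimal sketch of the companion methods this `__init__` relies on; the method names (`reset`, `sync_gradient`, `mark_param_as_ready`, `wait`) and the divide-by-world-size averaging convention are illustrative assumptions, not necessarily picotron's exact implementation.

    # Hypothetical companion methods on the same Bucket class (sketch only).

    def reset(self) -> None:
        # Clear per-iteration state so the bucket is ready for the next backward pass.
        self.handle = None
        self.params_with_grad_ready.clear()
        self.grad_data.zero_()

    def sync_gradient(self) -> None:
        # Launch an asynchronous all-reduce over the gradient buffer.
        # Pre-dividing by the group size turns the summed result into an average.
        assert self.handle is None, "An all-reduce is already in flight for this bucket"
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)

    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        # Record that this parameter's gradient has been accumulated into grad_data;
        # once every parameter in the bucket is ready, kick off the all-reduce.
        assert param in self.params and param not in self.params_with_grad_ready
        self.params_with_grad_ready.add(param)
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_gradient()

    def wait(self) -> None:
        # Block until the asynchronous all-reduce completes; only then are the
        # averaged gradients in grad_data safe to read.
        assert self.handle is not None, "Launch the all-reduce before waiting on it"
        self.handle.wait()

Launching the all-reduce as soon as a bucket fills lets gradient communication overlap with the remainder of the backward pass, which is the main motivation for bucketing gradients in data parallelism.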