# Bucket.__init__ — from picotron/data_parallel/bucket.py
    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
        """
        Create a gradient bucket.

        Args:
            params (List[torch.nn.Parameter]): Parameters whose gradients live in this bucket.
            grad_data (torch.Tensor): Flat tensor backing the gradients of every parameter in the bucket.
            process_group (torch.distributed.ProcessGroup): Process group over which the bucket's gradients are synchronized.
        """
        # Membership and readiness bookkeeping: the all-reduce is launched once
        # every parameter in the bucket has reported its gradient as ready.
        self.params = set(params)
        self.params_with_grad_ready = set()

        # Gradient storage and communication state.
        self.grad_data = grad_data
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group)
        self.handle = None  # Async all-reduce work handle (None while no op is in flight).

        self.reset()