def __init__()

in step7_pipeline_parallel_afab/data_parallel.py


    def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
        # Set of parameters in this bucket.
        self.params = set(params)
        # Parameters whose gradients are ready for synchronization; the
        # all-reduce is launched once every parameter in the bucket is ready.
        self.params_with_grad_ready = set()
        # Flat tensor holding the gradients of all parameters in this bucket.
        self.grad_data = grad_data
        # Process group over which gradients are all-reduced.
        self.process_group = process_group
        # Number of ranks in the group (used to average the summed gradients).
        self.process_group_size = dist.get_world_size(group=self.process_group)
        # Handle for the asynchronous all-reduce operation.
        self.handle = None
        
        self.reset()
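
The constructor ends with reset(), which initializes the bucket's synchronization state. For context, here is a minimal sketch of the companion methods such a bucket typically needs; the method names (mark_param_as_ready, sync_gradient, wait) and the exact behavior of reset() are illustrative assumptions, not taken from the source file.

    # Module-level imports assumed by this sketch.
    import torch
    import torch.distributed as dist

    def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
        # Hypothetical: record that this parameter's gradient has been
        # accumulated; once every parameter in the bucket is ready, launch
        # the all-reduce.
        self.params_with_grad_ready.add(param)
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_gradient()

    def sync_gradient(self) -> None:
        # Hypothetical: pre-divide by the group size so the summed result is
        # an average, then launch an asynchronous all-reduce that can overlap
        # with the rest of the backward pass.
        assert self.handle is None
        self.grad_data /= self.process_group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)

    def wait(self) -> None:
        # Hypothetical: block until the in-flight all-reduce completes.
        assert self.handle is not None
        self.handle.wait()

    def reset(self) -> None:
        # Hypothetical: clear the handle and the ready set, and zero the flat
        # gradient buffer before the next backward pass.
        self.handle = None
        self.params_with_grad_ready.clear()
        self.grad_data.zero_()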