def __init__()

in picotron/data_parallel/bucket.py [0:0]


    def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
        """
        Initializes the BucketManager.

        Args:
            params (List[torch.nn.Parameter]): List of model parameters.
            process_group (torch.distributed.ProcessGroup): Process group used for gradient synchronization.
            bucket_size (int): Maximum number of gradient elements per bucket.
            grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
        """
        self.params = list(params) # Convert parameter generator to a list.
        self.device = self.params[0].device if self.params[0].is_cuda else torch.device("cpu") # Device of the parameters (their CUDA device if on GPU, else CPU).
        self.buckets = [] # List of buckets.
        self.process_group = process_group
        self.process_group_size = dist.get_world_size(group=self.process_group) # Number of ranks in the data-parallel group.
        self.params_to_bucket_location = {} # Map each parameter to its location in a bucket: (start index, end index, bucket index).
        self.bucket_size = bucket_size # Target maximum number of gradient elements per bucket.
        self.bucket_sizes = None # Actual size of each bucket, filled in during bucket initialization.
        self.grad_data_list = [] # List of tensors to store gradients, one tensor per bucket.
        self.grad_type = grad_type # Data type used for the per-bucket gradient tensors.
        # Divide gradients into buckets based on the provided bucket size.
        self._initialize_buckets()
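
The docstring refers to the class as BucketManager, so a minimal construction sketch might look as follows. This is an illustrative example, not taken from the source: the model, the bucket size value, and the use of the default WORLD group are assumptions, and it presumes torch.distributed has already been initialized (e.g. under torchrun).

import torch
import torch.distributed as dist

from picotron.data_parallel.bucket import BucketManager

# Assumes the default process group was initialized beforehand, e.g.
# dist.init_process_group(backend="nccl") when launched with torchrun.
model = torch.nn.Linear(1024, 1024)

bucket_manager = BucketManager(
    params=model.parameters(),        # generator; converted to a list in __init__
    process_group=dist.group.WORLD,   # group used for gradient synchronization
    bucket_size=25 * 1024 * 1024,     # illustrative cap on gradient elements per bucket
    grad_type=torch.float32,          # dtype of the per-bucket gradient tensors
)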