in picotron/data_parallel/bucket.py
def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
"""
Initializes the BucketManager.
Args:
params (List[torch.nn.Parameter]): List of model parameters.
process_group (torch.distributed.ProcessGroup): Process group used for gradient synchronization.
bucket_size (int): Maximum size of each bucket in terms of gradient elements.
grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
"""
    self.params = list(params)  # Convert the parameter generator to a list.
    self.device = self.params[0].device if self.params[0].is_cuda else torch.device("cpu")
    self.buckets = []  # List of buckets.
    self.process_group = process_group
    self.process_group_size = dist.get_world_size(group=self.process_group)
    self.params_to_bucket_location = {}  # Map each parameter to its bucket location (start, end, bucket_idx).
    self.bucket_size = bucket_size
    self.bucket_sizes = None  # Actual sizes of each bucket.
    self.grad_data_list = []  # List of tensors to store gradients, one tensor per bucket.
    self.grad_type = grad_type
    # Divide gradients into buckets based on the provided bucket size.
    self._initialize_buckets()
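
For context, a minimal usage sketch follows. It assumes torch.distributed has already been initialized (e.g. one process per GPU launched via torchrun), that BucketManager is importable from the picotron/data_parallel/bucket.py module shown above, and that the bucket size used here is an illustrative value rather than a recommended one.

# Minimal usage sketch (illustrative values; assumes torch.distributed is
# already initialized, e.g. one process per GPU launched with torchrun).
import torch
import torch.distributed as dist

from picotron.data_parallel.bucket import BucketManager

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()

bucket_manager = BucketManager(
    params=model.parameters(),        # the generator is converted to a list internally
    process_group=dist.group.WORLD,   # all data-parallel ranks take part in the gradient sync
    bucket_size=25 * 1024 * 1024,     # illustrative cap on gradient elements per bucket
    grad_type=torch.float32,          # dtype of the per-bucket gradient tensors
)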