in picotron/data_parallel/bucket.py [0:0]
def _initialize_buckets(self) -> None:
    """
    Divide the trainable model parameters into buckets for gradient synchronization.

    Parameters with ``requires_grad`` are packed greedily, in ``self.params`` order,
    into flat gradient buffers of at most ``self.bucket_size`` elements. A parameter
    that would overflow the current non-empty bucket starts a new bucket. An empty
    bucket always accepts the next parameter, so a single parameter larger than
    ``self.bucket_size`` ends up alone in an oversized bucket.

    Side effects:
        - ``self.params_to_bucket_location[param]`` is set to ``(start, end, bucket_idx)``
          for every parameter with ``requires_grad``.
        - For each bucket, a zero-filled tensor (dtype ``self.grad_type``, device
          ``self.device``) is appended to ``self.grad_data_list``, and a ``Bucket``
          wrapping it is appended to ``self.buckets``.
        - ``param.main_grad`` is set to the view returned by
          ``self._get_view_from_tensor`` over the parameter's slice of its bucket buffer.

    If no parameter requires a gradient, no bucket is created.
    """
    cur_bucket_size = 0
    cur_bucket_idx = 0
    # Number of buckets that actually received a parameter. Stays 0 when nothing is
    # trainable, so no empty bucket is created.
    num_buckets = 0
    # Assign each trainable parameter a contiguous [start, end) slice of a bucket.
    for param in self.params:
        if not param.requires_grad:
            continue
        numel = param.numel()
        # Open a new bucket only when the current one already holds data and this
        # parameter would overflow it.
        if cur_bucket_size > 0 and cur_bucket_size + numel > self.bucket_size:
            cur_bucket_idx += 1
            cur_bucket_size = 0
        self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + numel, cur_bucket_idx)
        cur_bucket_size += numel
        num_buckets = cur_bucket_idx + 1

    # Gather each bucket's total size (largest end offset) and the parameters it holds.
    bucket_sizes = [0] * num_buckets
    buckets_to_params = [[] for _ in range(num_buckets)]
    for param, (_, end, idx) in self.params_to_bucket_location.items():
        bucket_sizes[idx] = max(bucket_sizes[idx], end)
        buckets_to_params[idx].append(param)

    # Allocate one gradient buffer per bucket and wrap it in a Bucket. Keep the new
    # buffers in a local list so bucket indices stay valid even if
    # self.grad_data_list was not empty beforehand.
    grad_buffers = []
    for size, bucket_params in zip(bucket_sizes, buckets_to_params):
        grad_data = torch.zeros(size, dtype=self.grad_type, device=self.device)
        grad_buffers.append(grad_data)
        self.grad_data_list.append(grad_data)
        self.buckets.append(Bucket(bucket_params, grad_data, self.process_group))

    # Point each trainable parameter's main_grad at its slice of the bucket buffer.
    # The view shares storage with that buffer, so the bucket sees gradients written
    # through main_grad.
    for param in reversed(self.params):
        if not param.requires_grad:
            continue
        data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
        # param.main_grad is used for gradient calculation
        param.main_grad = self._get_view_from_tensor(grad_buffers[bucket_id], param.shape, data_start_index, data_end_index)