in pretrain/PyTorch/distributed_apex.py [0:0]
def sync_bucket_structure(self):
    # Append leftover buckets
    for tmp_bucket in self.tmp_buckets:
        if len(tmp_bucket) > 0:
            self.active_i_buckets.append(tmp_bucket)

    self.num_buckets = len(self.active_i_buckets)
    self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]

    # Pack the bucket structure into a single int tensor:
    # [num_buckets, size_0, ..., size_{n-1}, flattened param indices]
    info_tensor = torch.cuda.IntTensor([self.num_buckets] +
                                       self.bucket_sizes +
                                       list(chain(*self.active_i_buckets)))

    # print("Sync Bucket Structure Broadcast. Rank", dist.get_rank(), "Tensor Size ", info_tensor.size(), "Device ", info_tensor.device, "Current Device ", torch.cuda.current_device())

    # Broadcast rank 0's bucket structure so every rank adopts the same layout.
    dist.broadcast(info_tensor, 0)

    # Decode the broadcast tensor back into per-bucket structures.
    info = [int(entry) for entry in info_tensor]
    self.num_buckets = info[0]
    self.bucket_sizes = info[1:self.num_buckets + 1]
    self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                    for i in range(self.num_buckets)]

    # Technically, active_i_buckets' work is done. But the information is still useful
    # to keep around, so refresh active_i_buckets based on rank 0 as well.
    self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
                             for i in range(self.num_buckets)]
    flattened_buckets = info[self.num_buckets + 1:]
    flat_i = 0
    for bucket_idx in range(self.num_buckets):
        for bucket_loc in range(self.bucket_sizes[bucket_idx]):
            param_i = flattened_buckets[flat_i]
            self.active_i_buckets[bucket_idx][bucket_loc] = param_i
            self.param_id_to_bucket[id(self.active_params[param_i])] = (
                bucket_idx, bucket_loc)
            flat_i += 1
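
The broadcast payload is a flat integer layout: the bucket count, then one size per bucket, then the concatenated parameter indices of every bucket. The sketch below is not part of the repository (pack_bucket_info and unpack_bucket_info are hypothetical helper names); it round-trips that encoding on a single process, assuming the same wire format as info_tensor above, to show how receiving ranks reconstruct rank 0's bucket lists.

# Minimal sketch of the wire format used by sync_bucket_structure (hypothetical helpers).
from itertools import chain
import torch

def pack_bucket_info(active_i_buckets):
    # Encode as [num_buckets, size_0, ..., size_{n-1}, flattened param indices].
    num_buckets = len(active_i_buckets)
    bucket_sizes = [len(bucket) for bucket in active_i_buckets]
    return torch.tensor([num_buckets] + bucket_sizes + list(chain(*active_i_buckets)),
                        dtype=torch.int32)

def unpack_bucket_info(info_tensor):
    # Decode the flat tensor back into per-bucket lists of parameter indices.
    info = [int(entry) for entry in info_tensor]
    num_buckets = info[0]
    bucket_sizes = info[1:num_buckets + 1]
    flattened = info[num_buckets + 1:]
    buckets, flat_i = [], 0
    for size in bucket_sizes:
        buckets.append(flattened[flat_i:flat_i + size])
        flat_i += size
    return buckets

# Round trip: the decoded structure matches the original bucket assignment,
# which is what every non-zero rank recovers after the broadcast.
original = [[3, 1], [0, 2, 4]]
assert unpack_bucket_info(pack_bucket_info(original)) == original

Keeping the structure in one tensor means a single collective call synchronizes everything; the receiving side only needs the leading count and the size list to know where each bucket's indices start and end.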