in finetune/PyTorch/azureml_bert_util.py [0:0]
def synchronize(self):
    """All-reduce accumulated gradients once the accumulation window closes.

    Returns True if gradients were synchronized on this call, else False.
    """
    synced = False
    if self.count_down == 0:
        # Launch all-reduce for any parameters whose gradients were never
        # submitted by the backward hooks.
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            self._allreduce_tensor(p)
        if self._multi_node:
            # Wait on each asynchronous all-reduce handle (the module-level
            # `synchronize` called here, presumably Horovod's, is shadowed by
            # this method's name), then decompress the result and average it
            # over the accumulation window.
            for p, (handle, ctx) in self._handles.items():
                output = synchronize(handle)
                p.grad.set_(self._compression.decompress(output, ctx) / self.accumulation_step)
        else:
            # Single-node path: bucket gradient tensors by type so each
            # bucket can be coalesced and reduced in one collective call.
            buckets = OrderedDict()
            for tensor in self._handles.values():
                tp = tensor.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(tensor)
            for tp in buckets:
                bucket = buckets[tp]
                # Pre-divide so the summing all-reduce yields gradients
                # averaged over workers and accumulation steps.
                coalesced = flatten(bucket) / self.world_size / self.accumulation_step
                torch.distributed.all_reduce_multigpu([coalesced])
                # Copy each reduced slice back into its source tensor; the
                # loop variable is named to not shadow the `synced` flag.
                for buf, reduced in zip(bucket, unflatten(coalesced, bucket)):
                    buf.copy_(reduced)
        self._handles.clear()
        synced = True
        self.count_down = self.accumulation_step
    self.count_down -= 1
    return synced
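
# Usage sketch (illustrative, not part of the original file): one way to
# drive synchronize() from a gradient-accumulation training loop. `comm`,
# `model`, `optimizer`, and `loader` are hypothetical stand-ins; only
# synchronize()'s contract is taken from the method above. Note that
# synchronize() already divides gradients by accumulation_step, so the
# loss is not rescaled per micro-batch here.
def _example_training_loop(comm, model, optimizer, loader):
    for batch in loader:
        loss = model(batch)
        loss.backward()         # backward hooks enqueue per-parameter reductions
        if comm.synchronize():  # True only once per accumulation window
            optimizer.step()    # p.grad now holds the synchronized average
            optimizer.zero_grad()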