def synchronize()

in finetune/PyTorch/azureml_bert_util.py [0:0]


    def synchronize(self):
        """Finish gradient all-reduction once every ``accumulation_step`` calls.

        Every call decrements ``self.count_down``. When ``self.count_down`` is
        ``0`` on entry, this call performs the reduction:

        * each parameter in ``self._requires_update`` that has no entry in
          ``self._handles`` yet is first handed to ``self._allreduce_tensor``;
        * multi-node (``self._multi_node`` truthy): each ``(handle, ctx)``
          entry of ``self._handles`` is waited on with the module-level
          ``synchronize`` function, decompressed with ``self._compression``,
          divided by ``self.accumulation_step`` and installed into ``p.grad``
          in place via ``Tensor.set_``;
        * single-node: the tensors stored in ``self._handles`` are grouped by
          ``Tensor.type()``. Each group is flattened into one buffer and
          divided by ``self.world_size`` and ``self.accumulation_step``. The
          buffer is then summed across processes with
          ``torch.distributed.all_reduce`` and copied back into the original
          tensors.

        Afterwards ``self._handles`` is cleared and ``self.count_down`` is
        reset to ``self.accumulation_step`` before the usual decrement.

        Returns:
            bool: ``True`` if this call performed the reduction, else ``False``.
        """
        synced = self.count_down == 0
        if synced:
            # NOTE(review): set iteration order of tensors is id()-based and
            # may differ between processes. If several params are missing,
            # their insertion order into self._handles (and hence the
            # single-node coalesced layout) could differ across ranks —
            # confirm that this cannot happen or sort by a stable key
            # (e.g. parameter name).
            missing_p = self._requires_update - set(self._handles)
            for p in missing_p:
                self._allreduce_tensor(p)

            if self._multi_node:
                # Entries are (async handle, compression context) pairs here.
                for p, (handle, ctx) in self._handles.items():
                    output = synchronize(handle)
                    p.grad.set_(self._compression.decompress(output, ctx) / self.accumulation_step)
            else:
                # Coalesce per tensor type so each dtype/device group needs a
                # single collective call; OrderedDict keeps bucket order stable.
                buckets = OrderedDict()
                for tensor in self._handles.values():
                    buckets.setdefault(tensor.type(), []).append(tensor)
                for bucket in buckets.values():
                    # Pre-scale so the SUM reduction yields the mean over
                    # processes and accumulation steps.
                    coalesced = flatten(bucket) / self.world_size / self.accumulation_step
                    # Equivalent to the deprecated all_reduce_multigpu([coalesced]).
                    torch.distributed.all_reduce(coalesced)
                    # `reduced` must not reuse the name `synced` (the return flag).
                    for buf, reduced in zip(bucket, unflatten(coalesced, bucket)):
                        buf.copy_(reduced)
            self._handles.clear()
            self.count_down = self.accumulation_step

        self.count_down -= 1
        return synced