in utils.py [0:0]
def get_cpu_stats_over_ranks(stat_dict):
    """Reduce a dict of scalar statistics across ranks and return host floats.

    Each value is converted to a float32 tensor on the current CUDA device.
    The values are stacked into one tensor and passed through a single
    ``allreduce(..., average=True)`` call. ``allreduce`` is a helper defined
    outside this function; presumably it averages over the process group,
    so confirm that against its definition. The reduced tensor is copied
    back to the CPU, and each entry is converted to a Python number.

    Args:
        stat_dict: Mapping from sortable keys to values that ``torch.as_tensor``
            accepts and that hold a single element each (Python numbers or
            one-element tensors). Presumably every rank passes the same key
            set, because the collective pairs values by position, not by key.

    Returns:
        A dict with the same keys, each mapped to its reduced value as a
        Python float. An empty input returns an empty dict, and no
        collective is issued.
    """
    if not stat_dict:
        # torch.stack raises on an empty list, and there is nothing to reduce.
        return {}
    # Sort the keys so every rank packs its values in the same order.
    # The allreduce combines elements positionally.
    keys = sorted(stat_dict.keys())
    packed = torch.stack(
        [torch.as_tensor(stat_dict[k]).detach().cuda().float() for k in keys]
    )
    reduced = allreduce(packed, average=True).cpu()
    return {k: reduced[i].item() for i, k in enumerate(keys)}