in src/dist_utils.py [0:0]
def backward(*grads):
    """Combine gradients across processes and return this rank's share.

    The incoming gradient tensors are stacked along a new leading
    dimension, the stack is all-reduced across the process group using
    the default reduction of ``dist.all_reduce``, and the entry whose
    index equals the calling process's rank is returned.

    NOTE(review): this looks like the backward half of an all-gather
    style autograd function, with one gradient per participating rank;
    confirm against the enclosing definition, which is not visible here.

    Args:
        *grads: Gradient tensors to combine. They must be stackable
            (``torch.stack`` requires a non-empty sequence of
            equally-shaped tensors).

    Returns:
        The slice of the reduced stack at index ``dist.get_rank()``.
    """
    # New leading axis: entry i is the i-th gradient that was passed in.
    stacked = torch.stack(grads, dim=0)

    # The reduction happens in place on ``stacked``; the return value of
    # the call is not needed.
    dist.all_reduce(stacked)

    rank = dist.get_rank()
    return stacked[rank]