def backward()

in src/dist_utils.py [0:0]


    def backward(*grads):
        """Combine stacked gradients across processes and return this rank's entry.

        The positional gradient tensors are stacked along a new leading
        dimension (``torch.stack`` requires them to share one shape). The
        stack is reduced in place across all processes with ``dist.all_reduce``
        using its default op (SUM). The element at index ``dist.get_rank()``
        is then returned.

        NOTE(review): the signature has no ``ctx`` parameter. If this is a
        ``torch.autograd.Function`` staticmethod, autograd passes ``ctx`` as
        the first positional argument, and it would land in ``grads[0]``.
        Confirm against the enclosing class/decorator.
        """
        stacked = torch.stack(grads)
        # Collective, in-place reduction: afterwards every rank holds the same
        # element-wise total of all ranks' stacked gradients.
        dist.all_reduce(stacked)
        my_rank = dist.get_rank()
        return stacked[my_rank]