def gather_nograd()

in src/dist_utils.py [0:0]


def gather_nograd(x: torch.Tensor) -> torch.Tensor:
    """Gather ``x`` from every rank and concatenate the copies along dim 0.

    The result is not connected to autograd: the whole gather runs under
    ``torch.no_grad()``, so no gradient flows back to ``x``.

    Args:
        x: Local tensor. Every rank in the default process group must pass a
            tensor of the same shape and dtype (an ``all_gather`` requirement).
            A 0-d tensor is treated as a tensor of shape ``(1,)``.

    Returns:
        A tensor of shape ``(world_size * x.shape[0], *x.shape[1:])`` holding
        each rank's tensor in rank order (shape ``(world_size,)`` for a 0-d
        input).
    """
    with torch.no_grad():
        # torch.cat cannot concatenate 0-d tensors, so promote scalars to 1-d.
        if x.dim() == 0:
            x = x.unsqueeze(0)
        # Collective backends such as NCCL reject non-contiguous tensors; this
        # also makes the empty_like receive buffers below contiguous.
        x = x.contiguous()
        # all_gather overwrites every receive buffer, so there is no need to
        # initialise them.
        x_gather = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(x_gather, x, async_op=False)
        return torch.cat(x_gather, dim=0)