in src/dist_utils.py [0:0]
def gather_nograd(x: torch.Tensor) -> torch.Tensor:
    """Gather ``x`` from every rank and concatenate the results along dim 0.

    Uses a synchronous ``dist.all_gather``; the result is not connected to
    the autograd graph, so no gradient flows back to ``x``.

    Args:
        x: Local tensor to contribute. ``all_gather`` requires every rank to
            pass a tensor of the same shape and dtype.

    Returns:
        A tensor of shape ``(world_size * x.shape[0], *x.shape[1:])``. Rank
        ``i``'s contribution occupies the ``i``-th slice along dim 0.
    """
    with torch.no_grad():
        # Collective backends (e.g. NCCL) reject non-contiguous tensors; a
        # contiguous source also makes the empty_like buffers contiguous.
        x = x.contiguous()
        # empty_like: all_gather overwrites every buffer, so skip the fill.
        x_gather = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(x_gather, x, async_op=False)
        return torch.cat(x_gather, dim=0)