in src/utils.py [0:0]
def forward(ctx, x):
    """Gather ``x`` from every rank and concatenate the results along dim 0.

    Presumably the ``forward`` of a ``torch.autograd.Function`` (it receives
    ``ctx``); the matching ``backward`` is not visible here. NOTE(review):
    confirm it expects the rank-ordered, dim-0 concatenated layout produced
    below.

    Args:
        ctx: Autograd context object; not used by this forward.
        x: This rank's local tensor. Every rank must contribute a tensor of
            the same shape, because each receive buffer is allocated with
            this rank's shape.

    Returns:
        When a process group with more than one rank is initialized, the
        concatenation along dim 0 of all ranks' tensors in rank order
        (``all_gather`` fills slot ``i`` with rank ``i``'s tensor). Otherwise
        ``x`` itself, unchanged.
    """
    if not (
        dist.is_available()
        and dist.is_initialized()
        and dist.get_world_size() > 1
    ):
        # Single-process / non-distributed run: nothing to gather.
        return x

    # c10d collective backends (e.g. NCCL) reject non-contiguous tensors;
    # .contiguous() returns x itself when it is already contiguous. This also
    # makes the empty_like buffers below contiguous.
    x = x.contiguous()
    world_size = dist.get_world_size()
    # all_gather overwrites every element of each buffer, so zero-filling
    # them first would be wasted work.
    outputs = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(outputs, x)
    return torch.cat(outputs, 0)