in src/dist_utils.py [0:0]
def forward(x: torch.Tensor) -> tuple:
    """Gather ``x`` from every rank of the default process group.

    This is a collective operation: every rank in the group has to call it,
    otherwise ``dist.all_gather`` blocks waiting for the missing peers.

    Args:
        x: Local tensor to share. The receive buffers are allocated with the
            local tensor's shape and dtype, so ``x`` is expected to have the
            same shape and dtype on every rank.

    Returns:
        A tuple with ``dist.get_world_size()`` tensors; the entry at index
        ``r`` holds the tensor contributed by rank ``r``. The entries are
        freshly allocated buffers, not views of ``x``.
    """
    # Some collective backends (e.g. NCCL) reject non-contiguous tensors, and
    # ``zeros_like`` would otherwise copy a dense non-contiguous layout into
    # the receive buffers as well. ``contiguous()`` is a no-op (returns the
    # same tensor) when ``x`` is already contiguous.
    x = x.contiguous()

    world_size = dist.get_world_size()
    output = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(output, x)
    return tuple(output)