def varsize_gather_nograd()

in src/dist_utils.py


import torch
import torch.distributed as dist


def varsize_gather_nograd(x: torch.Tensor):
    """Gather tensors of different sizes along the first dimension."""

    # share the local first-dimension size with every rank and compute the maximum
    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(allsizes, size)
    max_size = max(s.item() for s in allsizes)

    # pad the local tensor to the maximum size so every rank gathers equal shapes
    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x
    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(output, padded)

    # remove the padding from each rank's contribution and concatenate
    output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
    output = torch.cat(output, dim=0)

    return output
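
For reference, a minimal usage sketch follows. It is not part of the repository: the import path, world size, backend, and rendezvous settings are illustrative assumptions. It spawns two CPU processes with the gloo backend and checks that rows from ranks with different local sizes are gathered without padding.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from src.dist_utils import varsize_gather_nograd  # assumed import path; adjust to your layout


def _demo(rank: int, world_size: int):
    # illustrative rendezvous settings for a single-machine run
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # rank 0 contributes 1 row, rank 1 contributes 2 rows, ...
    x = torch.full((rank + 1, 4), float(rank))
    gathered = varsize_gather_nograd(x)

    # every rank ends up with the same concatenation of all rows, padding removed
    assert gathered.shape[0] == sum(r + 1 for r in range(world_size))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_demo, args=(2,), nprocs=2)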