def gather_strings()

in tzrec/utils/dist_util.py [0:0]


def gather_strings(s: str, dst: int = 0) -> List[str]:
    """Gather one string from every rank onto the destination rank.

    Each rank UTF-8 encodes its string. The byte lengths are exchanged with
    an all-gather so every rank can pad its payload to a common size, and the
    padded payloads are then gathered on ``dst``. This issues collectives on
    the default process group, so every rank must call it.

    Args:
        s: the string contributed by the calling rank.
        dst: rank that receives the gathered strings.

    Returns:
        On ``dst``, a list with one string per rank, ordered by rank.
        On every other rank, an empty list.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # The NCCL backend needs CUDA tensors; other backends are fed CPU tensors.
    if dist.get_backend() == dist.Backend.NCCL:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    payload = torch.ByteTensor(bytearray(s, "utf-8")).to(device)
    payload_len = payload.numel()

    # Step 1: share every rank's byte length with all ranks.
    local_len = torch.tensor([payload_len], dtype=torch.long, device=device)
    len_tensors = [
        torch.zeros(1, dtype=torch.long, device=device) for _ in range(world_size)
    ]
    dist.all_gather(len_tensors, local_len)
    lengths = [int(t.item()) for t in len_tensors]
    max_len = max(lengths)

    # Step 2: zero-pad to the common size so gather sees equal-sized tensors.
    padded = torch.zeros(max_len, dtype=torch.uint8, device=device)
    padded[:payload_len] = payload

    gather_list = None
    if rank == dst:
        gather_list = [
            torch.zeros(max_len, dtype=torch.uint8, device=device)
            for _ in range(world_size)
        ]
    dist.gather(padded, gather_list, dst)

    if gather_list is None:
        return []

    # Slice each buffer to its sender's true length instead of stripping
    # trailing zero bytes: a string may legitimately end with "\x00", and
    # stripping would silently truncate it.
    return [
        buf[:length].cpu().numpy().tobytes().decode("utf-8")
        for buf, length in zip(gather_list, lengths)
    ]