in tzrec/utils/dist_util.py [0:0]
def gather_strings(s: str, dst: int = 0) -> List[str]:
    """Gather strings from all ranks to the destination rank.

    Each rank UTF-8 encodes its string into a uint8 tensor. Byte lengths are
    all-gathered first so every rank can pad its tensor to a common size
    (``dist.gather`` requires equally shaped tensors). The padded tensors are
    then gathered on ``dst``, trimmed back to their exact lengths and decoded.

    Args:
        s: the string contributed by the calling rank.
        dst: global rank that receives the gathered strings.

    Returns:
        On ``dst``: a list of ``world_size`` strings ordered by rank.
        On every other rank: an empty list.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # NCCL collectives only accept CUDA tensors; other backends use CPU tensors.
    use_cuda = dist.get_backend() == dist.Backend.NCCL

    encoded = s.encode("utf-8")
    s_tensor = torch.tensor(list(encoded), dtype=torch.uint8)

    # Exchange byte lengths: needed both to pick a common padded size and so
    # that ``dst`` can trim every payload back to its exact length.
    local_len = torch.tensor([len(encoded)], dtype=torch.long)
    len_list = [torch.zeros(1, dtype=torch.long) for _ in range(world_size)]
    if use_cuda:
        local_len = local_len.cuda()
        len_list = [x.cuda() for x in len_list]
    dist.all_gather(len_list, local_len)
    lengths = torch.cat(len_list).cpu().tolist()
    # Keep at least one byte so the gather never runs on zero-sized tensors
    # (e.g. every rank sent ""); results are sliced by true length anyway.
    max_len = max(max(lengths), 1)

    padded_s_tensor = torch.cat(
        (s_tensor, torch.zeros(max_len - len(encoded), dtype=torch.uint8))
    )
    # Only the destination rank may supply receive buffers to dist.gather.
    gather_list = (
        [torch.zeros(max_len, dtype=torch.uint8) for _ in range(world_size)]
        if rank == dst
        else []
    )
    if use_cuda:
        padded_s_tensor = padded_s_tensor.cuda()
        gather_list = [x.cuda() for x in gather_list]
    dist.gather(padded_s_tensor, gather_list, dst)

    if rank != dst:
        return []
    # Slice by the exchanged lengths rather than stripping NUL padding, so
    # strings that legitimately end in "\x00" survive the round trip intact.
    return [
        tensor[:length].cpu().numpy().tobytes().decode("utf-8")
        for tensor, length in zip(gather_list, lengths)
    ]