# Function `_bootstrap` — from tzrec/datasets/sampler.py


def _bootstrap(group_size: int, local_rank: int, group_rank: int) -> str:
    def addr_to_tensor(ip: str, port: str) -> torch.Tensor:
        addr_array = [int(i) for i in (ip.split("."))] + [int(port)]
        addr_tensor = torch.tensor(addr_array, dtype=torch.int)
        return addr_tensor

    def tensor_to_addr(tensor: torch.Tensor) -> str:
        addr_array = tensor.tolist()
        addr = ".".join([str(i) for i in addr_array[:-1]]) + ":" + str(addr_array[-1])
        return addr

    def exchange_gl_server_info(
        addr_tensor: torch.Tensor, group_size: int, group_rank: int
    ) -> str:
        comm_tensor = torch.zeros([group_size, 5], dtype=torch.int32)
        comm_tensor[group_rank] = addr_tensor
        if dist.get_backend() == dist.Backend.NCCL:
            comm_tensor = comm_tensor.cuda()
        dist.all_reduce(comm_tensor, op=dist.ReduceOp.MAX)
        cluster_server_info = ",".join([tensor_to_addr(t) for t in comm_tensor])
        return cluster_server_info

    if local_rank == 0:
        local_ip = socket.gethostbyname(socket.gethostname())
        port = str(get_free_port(local_ip))
    else:
        local_ip = "0.0.0.0"
        port = "0"

    if not dist.is_initialized():  # stand-alone
        return local_ip + ":" + port

    gl_server_info = exchange_gl_server_info(
        addr_to_tensor(local_ip, port), group_size, group_rank
    )
    return gl_server_info