def generate_pairs()

in hugegraph-ml/src/hugegraph_ml/models/gatne.py [0:0]


def generate_pairs(all_walks, window_size, num_workers):
    """Generate the de-duplicated pairs for every layer's walks in parallel.

    Each layer's walks are split into contiguous chunks that are handed to
    ``generate_pairs_parallel`` through a process pool; the per-chunk results
    are merged and duplicates are removed.

    Args:
        all_walks: Sequence of per-layer walk collections. The position of a
            collection is forwarded to the workers as its ``layer_id``; each
            collection must support ``len()`` and slicing.
        window_size: Window size; half of it (integer division) is forwarded
            to the workers as ``skip_window``.
        num_workers: Number of worker processes, and the number of chunks a
            layer is split into when it holds at least that many walks.

    Returns:
        A NumPy array with one row per unique pair. Row order is unspecified
        because duplicates are removed through a ``set``.
    """
    start_time = time.time()
    print(f"We are generating pairs with {num_workers} cores.")

    skip_window = window_size // 2
    pairs = []
    # The context manager tears the pool down even when a worker raises, so
    # no worker processes are leaked on the error path.
    with multiprocessing.Pool(processes=num_workers) as pool:
        for layer_id, walks in enumerate(all_walks):
            block_num = len(walks) // num_workers
            if block_num > 0:
                walks_list = [
                    walks[i * block_num : (i + 1) * block_num]
                    for i in range(num_workers - 1)
                ]
                # The last chunk runs to the end of the layer so that the
                # ``len(walks) % num_workers`` trailing walks are not dropped.
                walks_list.append(walks[(num_workers - 1) * block_num :])
            else:
                # Fewer walks than workers: process the layer as one chunk.
                walks_list = [walks]

            worker = partial(
                generate_pairs_parallel,
                skip_window=skip_window,
                layer_id=layer_id,
            )
            # Results must be hashable records, since they are de-duplicated
            # with ``set`` below. Extending in place avoids re-copying the
            # accumulated list for every chunk.
            for chunk_pairs in pool.map(worker, walks_list):
                pairs.extend(chunk_pairs)

    end_time = time.time()
    print(f"Generate pairs end, use {end_time - start_time}s.")
    return np.array([list(pair) for pair in set(pairs)])