in hugegraph-ml/src/hugegraph_ml/models/gatne.py [0:0]
def generate_pairs(all_walks, window_size, num_workers):
    """Build de-duplicated training pairs from random walks using a process pool.

    Each layer's walks are split into contiguous chunks (one task per chunk),
    and ``generate_pairs_parallel`` is applied to every chunk in a
    ``multiprocessing.Pool`` of ``num_workers`` processes. The pairs produced
    for all layers are merged and duplicates are removed.

    Args:
        all_walks: Iterable of per-layer walk collections. The position of each
            item is passed to the worker as ``layer_id``. Each item must
            support ``len()`` and slicing.
        window_size: Context window size; ``window_size // 2`` is passed to the
            worker as ``skip_window``.
        num_workers: Number of worker processes (must be >= 1).

    Returns:
        numpy.ndarray with one row per unique pair (each pair converted to a
        list). Row order is unspecified because de-duplication goes through a
        ``set``, so pairs must be hashable. The pair layout is whatever
        ``generate_pairs_parallel`` produces -- presumably
        (node, context, layer_id); confirm against that function.
    """
    start_time = time.time()
    print(f"We are generating pairs with {num_workers} cores.")
    skip_window = window_size // 2
    # Deduplicate as results arrive instead of accumulating every duplicate
    # in one big list first.
    unique_pairs = set()
    # The context manager cleans the workers up even if a task raises. All
    # results have been consumed by the synchronous pool.map calls before the
    # pool is terminated on exit.
    with multiprocessing.Pool(processes=num_workers) as pool:
        for layer_id, walks in enumerate(all_walks):
            # Ceiling division so the chunks cover *every* walk. Floor division
            # would drop the last len(walks) % num_workers walks of each layer.
            chunk_size = max(1, -(-len(walks) // num_workers))
            walks_list = [
                walks[start : start + chunk_size]
                for start in range(0, len(walks), chunk_size)
            ] or [walks]  # empty layer: still submit a single (empty) task
            tmp_result = pool.map(
                partial(
                    generate_pairs_parallel,
                    skip_window=skip_window,
                    layer_id=layer_id,
                ),
                walks_list,
            )
            for chunk_pairs in tmp_result:
                unique_pairs.update(chunk_pairs)
    end_time = time.time()
    print(f"Generate pairs end, use {end_time - start_time}s.")
    return np.array([list(pair) for pair in unique_pairs])