in muss/mining/nn_search.py [0:0]
def combine_results_over_db_indexes(intermediary_results_paths, offsets):
    """Merge several saved nearest-neighbour results into one top-k result.

    Every path is loaded with ``load_results``, which must yield a
    ``(distances, sentence_ids)`` pair of 2-D arrays. The matching offset is
    added in place to the loaded sentence ids (presumably to turn ids local to
    one db index into global ids -- confirm against the caller). Results are
    then folded together one at a time: candidates are concatenated along
    axis 1 and, for each row, only the ``topk`` entries with the smallest
    distances are kept, where ``topk`` is the number of columns of the results.

    Args:
        intermediary_results_paths: Iterable of paths accepted by
            ``load_results``.
        offsets: Iterable of values added to the sentence ids loaded from the
            path at the same position. Paths and offsets are paired with
            ``zip``, so surplus entries of the longer iterable are ignored.

    Returns:
        Tuple ``(distances, sentence_ids)`` of arrays. When more than one
        result is combined, each row is ordered by increasing distance; a
        single result is returned as loaded (with its offset applied).

    Raises:
        ValueError: If there is nothing to combine, or if the results being
            merged do not all have the same number of columns.
    """

    def combine_distances_and_ids(distances_list, sentence_ids_list):
        """Concatenate candidates column-wise and keep the topk smallest distances per row."""
        topk = distances_list[0].shape[1]
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if any(distances.shape[1] != topk for distances in distances_list):
            raise ValueError(
                'All results must have the same number of columns (topk), got '
                f'{[distances.shape[1] for distances in distances_list]}'
            )
        distances = np.concatenate(distances_list, axis=1)
        sentence_ids = np.concatenate(sentence_ids_list, axis=1)
        # Column positions of the topk smallest distances of each row, in ascending order.
        kept_indexes = np.argsort(distances, axis=1)[:, :topk]
        return (
            np.take_along_axis(distances, kept_indexes, axis=1),
            np.take_along_axis(sentence_ids, kept_indexes, axis=1),
        )

    # Materialise the pairs so that emptiness can be detected and tqdm knows the total.
    paths_and_offsets = list(zip(intermediary_results_paths, offsets))
    if not paths_and_offsets:
        raise ValueError('Cannot combine results: no (results_path, offset) pair was provided.')

    all_distances = all_sentence_ids = None
    for results_path, offset in tqdm(paths_and_offsets, desc='Combine db indexes'):
        distances, sentence_ids = load_results(results_path)
        # In-place shift of the freshly loaded ids by this result's offset.
        sentence_ids += offset
        if all_distances is None:
            # No need to combine at first iteration
            all_distances, all_sentence_ids = distances, sentence_ids
            continue
        all_distances, all_sentence_ids = combine_distances_and_ids(
            [all_distances, distances], [all_sentence_ids, sentence_ids]
        )
    return all_distances, all_sentence_ids