def combine_results_over_db_indexes()

in muss/mining/nn_search.py [0:0]


def combine_results_over_db_indexes(intermediary_results_paths, offsets):
    """Merge nearest-neighbour results computed against several db indexes.

    Each path is loaded with ``load_results`` into a ``(distances,
    sentence_ids)`` pair, the matching offset is added to the ids, and the
    pairs are folded together one at a time, keeping for every row only the
    ``topk`` smallest distances (``topk`` being the number of columns of the
    loaded arrays).

    Args:
        intermediary_results_paths: Iterable of paths understood by
            ``load_results``, one per db index.
        offsets: Iterable of values added to the sentence ids of the
            corresponding path (presumably the position of each db index's
            first sentence in the global numbering; verify against caller).
            Paired with the paths via ``zip``, so surplus items on either
            side are ignored.

    Returns:
        Tuple ``(all_distances, all_sentence_ids)``. When more than one
        result is combined, every row is sorted by ascending distance; with
        a single path the loaded arrays are returned as-is (ids offset).

    Raises:
        ValueError: If there is no ``(path, offset)`` pair to combine.
    """

    def combine_distances_and_ids(distances_list, sentence_ids_list):
        """Keep, per row, the topk smallest distances and their ids."""
        topk = distances_list[0].shape[1]
        assert all(distances.shape[1] == topk for distances in distances_list)
        # Put all candidates of a row side by side ...
        distances = np.concatenate(distances_list, axis=1)
        sentence_ids = np.concatenate(sentence_ids_list, axis=1)
        # ... then keep the columns holding the topk smallest distances.
        kept_indexes = np.argsort(distances, axis=1)[:, :topk]
        return (
            np.take_along_axis(distances, kept_indexes, axis=1),
            np.take_along_axis(sentence_ids, kept_indexes, axis=1),
        )

    # Materialised once: lets tqdm display a total and lets us detect the
    # empty case up front instead of failing on unbound locals at `return`.
    paths_and_offsets = list(zip(intermediary_results_paths, offsets))
    if not paths_and_offsets:
        raise ValueError(
            'combine_results_over_db_indexes() needs at least one (results_path, offset) pair'
        )

    for i, (results_path, offset) in enumerate(tqdm(paths_and_offsets, desc='Combine db indexes')):
        distances, sentence_ids = load_results(results_path)
        # In-place shift of the freshly loaded ids by this db index's offset.
        sentence_ids += offset
        if i == 0:
            # No need to combine at first iteration
            all_distances, all_sentence_ids = distances, sentence_ids
            continue
        # Fold the new results into the running top-k; memory stays bounded
        # to two result sets at a time.
        all_distances, all_sentence_ids = combine_distances_and_ids(
            [all_distances, distances], [all_sentence_ids, sentence_ids]
        )
    return all_distances, all_sentence_ids