def get_nearest_sentence_ids()

in muss/mining/nn_search.py [0:0]


def get_nearest_sentence_ids(query_index, db_index, topk, nprobe, batch_size=1024, use_gpu=True):
    """Search ``db_index`` for the ``topk`` nearest neighbours of every vector in ``query_index``.

    Query vectors are reconstructed from ``query_index`` and searched in batches of
    at most ``batch_size`` so that the full query matrix is never held in memory at once.

    Args:
        query_index: Index holding the query vectors; must expose ``ntotal`` and
            ``reconstruct_n(start, count)``.
        db_index: Index to search; must expose ``search(embeddings, topk)``.
        topk: Number of neighbours to retrieve per query.
        nprobe: Value for the ``nprobe`` search parameter of ``db_index``. It is
            silently ignored when the index reports that it has no such parameter.
        batch_size: Maximum number of queries searched per call.
        use_gpu: If true, ``db_index`` is copied to all GPUs before searching.

    Returns:
        A ``(distances, sentence_ids)`` tuple of arrays, both of shape
        ``(query_index.ntotal, topk)``. ``sentence_ids`` has an integer dtype.
        When the search returned similarities (larger means closer), they are
        converted into pseudo-distances with ``1 - similarity``.

    Raises:
        RuntimeError: Any error raised while setting ``nprobe`` other than the
            "could not set parameter nprobe" one.
    """
    try:
        faiss.ParameterSpace().set_index_parameter(db_index, 'nprobe', nprobe)
    except RuntimeError as e:
        # Indexes without an nprobe parameter reject the call; anything else is a real error.
        if 'could not set parameter nprobe' not in str(e):
            raise
    # Read the metric from the index we were given, before it is (possibly) wrapped for GPU use.
    # It is only used as a fallback when the result ordering cannot be observed (see below).
    is_inner_product = getattr(db_index, 'metric_type', None) == faiss.METRIC_INNER_PRODUCT
    if use_gpu:
        db_index = faiss.index_cpu_to_all_gpus(db_index)
    n_queries = query_index.ntotal
    all_distances = np.empty((n_queries, topk))
    all_sentence_ids = np.empty((n_queries, topk), dtype=int)
    # Stepping by batch_size never yields an empty trailing batch, even when n_queries is an
    # exact multiple of batch_size or zero.
    for start_idx in range(0, n_queries, batch_size):
        end_idx = min(start_idx + batch_size, n_queries)
        query_embeddings = query_index.reconstruct_n(start_idx, end_idx - start_idx)  # TODO: Do this in the background
        distances, sentence_ids = db_index.search(query_embeddings, topk)
        all_distances[start_idx:end_idx] = distances
        all_sentence_ids[start_idx:end_idx] = sentence_ids
    # If the scores are sorted in descending order they are similarities, and we turn them into
    # ascending pseudo-distances for the downstream code to work.
    neighbour_steps = np.diff(all_distances)
    if neighbour_steps.size > 0:
        scores_are_similarities = bool(np.all(neighbour_steps <= 0))
    else:
        # With topk == 1 or no queries there is nothing to compare and np.all(...) would be
        # vacuously True, so rely on the metric declared by the index instead.
        scores_are_similarities = is_inner_product
    if scores_are_similarities:
        # This is tailored for transforming cosine similarity into a pseudo-distance: the maximum
        # cosine similarity is 1 (vectors are equal). Hence distance = 1 - cosine will always be
        # positive and will be equal to 0 when vectors are equal.
        all_distances = 1 - all_distances
    # all_sentence_ids was allocated with dtype=int, so no further cast is needed.
    return all_distances, all_sentence_ids