def search_with_capped_res()

in isc/descriptor_matching.py [0:0]


def search_with_capped_res(xq, xb, num_results, metric=faiss.METRIC_L2):
    """
    Searches xq into xb, with a maximum total number of results.

    Every query vector in ``xq`` is range-searched exhaustively against the
    database ``xb``. The cap applies to the total number of (query, database)
    pairs over all queries, not to each query. If the search returns more
    than ``num_results`` pairs, only the ``num_results`` best ones are kept:
    smallest distances for L2, largest scores otherwise.

    Args:
        xq: query vectors (presumably a 2-D array of shape (nq, d); confirm
            against callers).
        xb: database vectors; ``xb.shape[1]`` is used as the dimension.
        num_results: maximum total number of results to return.
        metric: ``faiss.METRIC_L2`` or ``faiss.METRIC_INNER_PRODUCT``. Any
            other value is handled like inner product below, which is
            probably not meaningful (NOTE(review): only L2 and IP appear
            to be intended).

    Returns:
        ``(lims, dis, ids)`` in faiss range-search layout. The results of
        query ``i`` are ``dis[lims[i]:lims[i + 1]]`` and
        ``ids[lims[i]:lims[i + 1]]``.

    Not thread-safe for inner-product search: it temporarily replaces
    ``exhaustive_search.apply_maxres`` at module level.
    """
    index = faiss.IndexFlat(xb.shape[1], metric)
    index.add(xb)

    patch_ip = metric == faiss.METRIC_INNER_PRODUCT
    if patch_ip:
        # this is a very ugly hack because contrib.exhaustive_search does
        # not support IP search correctly. Do not use in a multithreaded env.
        apply_maxres_saved = exhaustive_search.apply_maxres
        exhaustive_search.apply_maxres = apply_maxres_IP

    try:
        _radius, lims, dis, ids = exhaustive_search.range_search_max_results(
            index, query_iterator(xq),
            # initial radius does not filter anything
            1e10 if metric == faiss.METRIC_L2 else -1e10,
            max_results=2 * num_results,
            min_results=num_results,
            ngpu=-1,  # use GPU if available
        )
    finally:
        # Always undo the module-level patch, even if the search raised,
        # so later callers of exhaustive_search are not affected.
        if patch_ip:
            exhaustive_search.apply_maxres = apply_maxres_saved

    n = len(dis)
    if n > num_results:
        # crop to num_results exactly
        mask = np.zeros(n, dtype=bool)
        if num_results > 0:
            # Guarded: for IP, argpartition(n) would be out of bounds when
            # num_results == 0; keeping nothing is the correct result there.
            if metric == faiss.METRIC_L2:
                keep = dis.argpartition(num_results)[:num_results]
            else:
                keep = dis.argpartition(n - num_results)[-num_results:]
            mask[keep] = True

        # kept_before[j] = number of kept entries among the first j results.
        # Sampling it at the old boundaries yields the new boundaries. This
        # relies on lims[0] == 0, which is the range-search layout.
        kept_before = np.zeros(n + 1, dtype=np.int64)
        kept_before[1:] = np.cumsum(mask)
        lims = kept_before[lims]
        dis = dis[mask]
        ids = ids[mask]

    return lims, dis, ids