in isc/descriptor_matching.py [0:0]
def search_with_capped_res(xq, xb, num_results, metric=faiss.METRIC_L2):
    """
    Searches xq into xb, with a maximum total number of results.

    Runs an exhaustive range search of every query vector in ``xq`` against
    the database vectors ``xb``. It keeps at most ``num_results`` matches in
    total, summed over all queries. The best-scoring matches are retained:
    the smallest values for ``faiss.METRIC_L2``, and the largest values for
    any other metric (e.g. ``faiss.METRIC_INNER_PRODUCT``).

    Args:
        xq: query vectors; ``len(xq)`` is taken as the number of queries.
        xb: database vectors; ``xb.shape[1]`` is used as the dimensionality
            of the flat index.
        num_results: maximum total number of (query, database) matches to
            return across all queries.
        metric: faiss metric used to build the ``IndexFlat``.

    Returns:
        ``(lims, dis, ids)`` in range-search layout: the matches of query
        ``i`` are ``dis[lims[i]:lims[i + 1]]`` and ``ids[lims[i]:lims[i + 1]]``.

    Note:
        For ``faiss.METRIC_INNER_PRODUCT`` this temporarily replaces the
        module-level ``exhaustive_search.apply_maxres``, so it must not be
        called concurrently from several threads.
    """
    index = faiss.IndexFlat(xb.shape[1], metric)
    index.add(xb)

    patch_ip = metric == faiss.METRIC_INNER_PRODUCT
    if patch_ip:
        # HACK: contrib.exhaustive_search does not support IP search
        # correctly, so swap in an IP-aware apply_maxres for this call only.
        # Do not use in a multithreaded env.
        apply_maxres_saved = exhaustive_search.apply_maxres
        exhaustive_search.apply_maxres = apply_maxres_IP
    try:
        _, lims, dis, ids = exhaustive_search.range_search_max_results(
            index,
            query_iterator(xq),
            # initial radius does not filter anything
            1e10 if metric == faiss.METRIC_L2 else -1e10,
            max_results=2 * num_results,
            min_results=num_results,
            ngpu=-1,  # use GPU if available
        )
    finally:
        if patch_ip:
            # Restore even when the search raises, so a failure does not leave
            # the module permanently patched for later callers.
            exhaustive_search.apply_maxres = apply_maxres_saved

    n = len(dis)
    if n > num_results:
        # Crop to num_results exactly, keeping the best-scoring matches.
        if metric == faiss.METRIC_L2:
            keep = dis.argpartition(num_results)[:num_results]
        else:
            keep = dis.argpartition(n - num_results)[-num_results:]
        mask = np.zeros(n, bool)
        mask[keep] = True
        # kept_before[j] = number of kept matches among the first j entries;
        # evaluating it at the old boundaries gives the new boundaries, which
        # keeps each query's surviving matches contiguous and in order.
        kept_before = np.zeros(n + 1, dtype=np.int64)
        kept_before[1:] = np.cumsum(mask)
        lims, dis, ids = kept_before[lims], dis[mask], ids[mask]
    return lims, dis, ids