in faiss/IndexIVF.cpp [385:648]
void IndexIVF::search_preassigned(
        idx_t n,
        const float* x,
        idx_t k,
        const idx_t* keys,
        const float* coarse_dis,
        float* distances,
        idx_t* labels,
        bool store_pairs,
        const IVFSearchParameters* params,
        IndexIVFStats* ivf_stats) const {
    FAISS_THROW_IF_NOT(k > 0);

    idx_t nprobe = params ? params->nprobe : this->nprobe;
    nprobe = std::min((idx_t)nlist, nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    idx_t max_codes = params ? params->max_codes : this->max_codes;
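    // Per-call params, when provided, override the index-level defaults.
    // max_codes == 0 means "no limit": the early-exit test in the
    // per-query loop below only fires when max_codes is non-zero.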
    size_t nlistv = 0, ndis = 0, nheap = 0;

    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;
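    // For inner product, higher scores are better, so a min-heap (CMin)
    // keeps the worst of the current top-k at the root, ready to be
    // evicted. For L2, lower distances are better, hence a max-heap.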
    bool interrupt = false;
    std::mutex exception_mutex;
    std::string exception_string;

    int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
    bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);

    bool do_parallel = omp_get_max_threads() >= 2 &&
            (pmode == 0           ? false
                     : pmode == 3 ? n > 1
                     : pmode == 1 ? nprobe > 1
                                  : nprobe * n > 1);
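    // The parallelization strategy depends on pmode: 0 does not
    // parallelize here (batching over queries is expected to happen in
    // the caller, as IndexIVF::search does), 3 splits over queries,
    // 1 over the probes of a single query, and 2 over (query, probe)
    // pairs. In each case the parallel region is only worth entering
    // when there is more than one unit of work.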
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
    {
        InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
        ScopeDeleter1<InvertedListScanner> del(scanner);
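        // each thread gets its own scanner, freed when the thread
        // leaves the parallel region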
        /*****************************************************
         * Depending on parallel_mode, there are two possible ways
         * to organize the search. Here we define local functions
         * that are in common between the two
         ******************************************************/

        // initialize + reorder a result heap
        auto init_result = [&](float* simi, idx_t* idxi) {
            if (!do_heap_init)
                return;
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_heapify<HeapForIP>(k, simi, idxi);
            } else {
                heap_heapify<HeapForL2>(k, simi, idxi);
            }
        };
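        // heap_heapify fills the k slots with sentinel values (the worst
        // possible score for the metric), so any scanned code that beats
        // a sentinel can be pushed into the heap.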
        auto add_local_results = [&](const float* local_dis,
                                     const idx_t* local_idx,
                                     float* simi,
                                     idx_t* idxi) {
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
            } else {
                heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
            }
        };
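        // merges a thread-local top-k heap into the output heap of a
        // query; heap_addn pushes the k local candidates one by one.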
        auto reorder_result = [&](float* simi, idx_t* idxi) {
            if (!do_heap_init)
                return;
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_reorder<HeapForIP>(k, simi, idxi);
            } else {
                heap_reorder<HeapForL2>(k, simi, idxi);
            }
        };
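        // heap_reorder converts the binary heap into a sorted result
        // list (best result first), the output format callers expect.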
        // single list scan using the current scanner (with query
        // set properly) and storing results in simi and idxi
        auto scan_one_list = [&](idx_t key,
                                 float coarse_dis_i,
                                 float* simi,
                                 idx_t* idxi) {
            if (key < 0) {
                // not enough centroids for multiprobe
                return (size_t)0;
            }
            FAISS_THROW_IF_NOT_FMT(
                    key < (idx_t)nlist,
                    "Invalid key=%" PRId64 " nlist=%zd\n",
                    key,
                    nlist);

            size_t list_size = invlists->list_size(key);

            // don't waste time on empty lists
            if (list_size == 0) {
                return (size_t)0;
            }

            scanner->set_list(key, coarse_dis_i);

            nlistv++;

            try {
                InvertedLists::ScopedCodes scodes(invlists, key);
                std::unique_ptr<InvertedLists::ScopedIds> sids;
                const Index::idx_t* ids = nullptr;

                if (!store_pairs) {
                    sids.reset(new InvertedLists::ScopedIds(invlists, key));
                    ids = sids->get();
                }

                nheap += scanner->scan_codes(
                        list_size, scodes.get(), ids, simi, idxi, k);

            } catch (const std::exception& e) {
                std::lock_guard<std::mutex> lock(exception_mutex);
                exception_string =
                        demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
                interrupt = true;
                return size_t(0);
            }

            return list_size;
        };
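        // scan_one_list returns the number of codes scanned, which the
        // callers accumulate to enforce the max_codes budget. When
        // store_pairs is set, ids stays null and the scanner encodes
        // (list_id, offset) pairs directly into the result labels.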
        /****************************************************
         * Actual loops, depending on parallel_mode
         ****************************************************/

        if (pmode == 0 || pmode == 3) {
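            // pmode 0/3: each thread handles complete queries. With
            // pmode 0, do_parallel is false, so the omp for below runs
            // on the single thread of the (inactive) parallel region.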
#pragma omp for
            for (idx_t i = 0; i < n; i++) {
                if (interrupt) {
                    continue;
                }

                // loop over queries
                scanner->set_query(x + i * d);
                float* simi = distances + i * k;
                idx_t* idxi = labels + i * k;

                init_result(simi, idxi);

                idx_t nscan = 0;

                // loop over probes
                for (size_t ik = 0; ik < nprobe; ik++) {
                    nscan += scan_one_list(
                            keys[i * nprobe + ik],
                            coarse_dis[i * nprobe + ik],
                            simi,
                            idxi);

                    if (max_codes && nscan >= max_codes) {
                        break;
                    }
                }

                ndis += nscan;
                reorder_result(simi, idxi);

                if (InterruptCallback::is_interrupted()) {
                    interrupt = true;
                }
            } // parallel for
        } else if (pmode == 1) {
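            // pmode 1: all threads cooperate on one query at a time.
            // Each thread scans a subset of the probes into a
            // thread-local heap; the local heaps are then merged into
            // the output under a critical section.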
            std::vector<idx_t> local_idx(k);
            std::vector<float> local_dis(k);

            for (size_t i = 0; i < n; i++) {
                scanner->set_query(x + i * d);
                init_result(local_dis.data(), local_idx.data());

#pragma omp for schedule(dynamic)
                for (idx_t ik = 0; ik < nprobe; ik++) {
                    ndis += scan_one_list(
                            keys[i * nprobe + ik],
                            coarse_dis[i * nprobe + ik],
                            local_dis.data(),
                            local_idx.data());

                    // the max_codes test cannot be applied here: the
                    // scanned-code count is spread across threads
                }
                // merge thread-local results
                float* simi = distances + i * k;
                idx_t* idxi = labels + i * k;
#pragma omp single
                init_result(simi, idxi);

#pragma omp barrier
#pragma omp critical
                {
                    add_local_results(
                            local_dis.data(), local_idx.data(), simi, idxi);
                }
#pragma omp barrier
#pragma omp single
                reorder_result(simi, idxi);
            }
        } else if (pmode == 2) {
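            // pmode 2: work items are (query, probe) pairs. Each pair is
            // scanned into a fresh local heap of size k, then merged into
            // the query's output heap under a critical section.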
            std::vector<idx_t> local_idx(k);
            std::vector<float> local_dis(k);

#pragma omp single
            for (int64_t i = 0; i < n; i++) {
                init_result(distances + i * k, labels + i * k);
            }

#pragma omp for schedule(dynamic)
            for (int64_t ij = 0; ij < n * nprobe; ij++) {
                size_t i = ij / nprobe;
                size_t j = ij % nprobe;

                scanner->set_query(x + i * d);
                init_result(local_dis.data(), local_idx.data());
                ndis += scan_one_list(
                        keys[ij],
                        coarse_dis[ij],
                        local_dis.data(),
                        local_idx.data());

#pragma omp critical
                {
                    add_local_results(
                            local_dis.data(),
                            local_idx.data(),
                            distances + i * k,
                            labels + i * k);
                }
            }

#pragma omp single
            for (int64_t i = 0; i < n; i++) {
                reorder_result(distances + i * k, labels + i * k);
            }
        } else {
            FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
        }
    } // parallel section

    if (interrupt) {
        if (!exception_string.empty()) {
            FAISS_THROW_FMT(
                    "search interrupted with: %s", exception_string.c_str());
        } else {
            FAISS_THROW_MSG("computation interrupted");
        }
    }

    if (ivf_stats) {
        ivf_stats->nq += n;
        ivf_stats->nlist += nlistv;
        ivf_stats->ndis += ndis;
        ivf_stats->nheap_updates += nheap;
    }
}
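// Usage sketch (illustrative, not part of the original file): the caller
// is expected to run the coarse quantizer first to obtain keys and
// coarse_dis, which is what IndexIVF::search does internally. Buffer
// sizes follow the layout the function assumes:
//
//     idx_t n = ...;                    // number of queries
//     idx_t k = ...;                    // results per query
//     std::vector<idx_t> keys(n * index.nprobe);
//     std::vector<float> coarse_dis(n * index.nprobe);
//     index.quantizer->search(
//             n, x, index.nprobe, coarse_dis.data(), keys.data());
//
//     std::vector<float> distances(n * k);
//     std::vector<idx_t> labels(n * k);
//     index.search_preassigned(
//             n, x, k, keys.data(), coarse_dis.data(),
//             distances.data(), labels.data(),
//             false /* store_pairs */, nullptr /* params */,
//             nullptr /* ivf_stats */);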