in faiss/IndexIVF.cpp [680:818]
/** Range search over inverted lists whose coarse assignment (keys and
 * coarse distances) has already been computed by the caller.
 *
 * @param nx          number of queries
 * @param x           query vectors, nx * d floats
 * @param radius      search radius, forwarded to the list scanner
 * @param keys        list ids, nx * nprobe entries; negative ids are skipped
 * @param coarse_dis  coarse distances, same layout as keys
 * @param result      output, filled through RangeSearchPartialResult
 * @param store_pairs forwarded to get_InvertedListScanner
 * @param params      optional, overrides this->nprobe
 * @param stats       optional, nq / nlist / ndis are incremented
 *
 * parallel_mode (the PARALLEL_MODE_NO_HEAP_INIT bit is masked off and is
 * otherwise ignored here):
 *   0: queries are split over threads
 *   1: for each query, its nprobe lists are split over threads
 *   2: the flattened (query, list) pairs are split over threads
 *   3: no OpenMP parallelism, queries are processed sequentially
 *
 * Exceptions must not escape an OpenMP parallel region, so errors raised
 * while scanning a list are recorded and rethrown once all threads joined.
 */
void IndexIVF::range_search_preassigned(
        idx_t nx,
        const float* x,
        float radius,
        const idx_t* keys,
        const float* coarse_dis,
        RangeSearchResult* result,
        bool store_pairs,
        const IVFSearchParameters* params,
        IndexIVFStats* stats) const {
    idx_t nprobe = params ? params->nprobe : this->nprobe;
    nprobe = std::min((idx_t)nlist, nprobe);

    // Strip the heap-init flag so that both the parallelism decision and the
    // dispatch below see the bare mode (previously only do_parallel did).
    const int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;

    // Validate here rather than inside the parallel region, where a throw
    // would escape the OpenMP structured block.
    if (pmode < 0 || pmode > 3) {
        FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
    }

    size_t nlistv = 0, ndis = 0;

    // Errors raised inside the parallel region are stored here and rethrown
    // after the region ends.
    bool interrupt = false;
    std::mutex exception_mutex;
    std::string exception_string;

    // One partial result per thread; merged together in modes 1 and 2.
    std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());

    // don't start parallel section if single query
    const bool do_parallel = omp_get_max_threads() >= 2 &&
            (pmode == 3           ? false
                     : pmode == 0 ? nx > 1
                     : pmode == 1 ? nprobe > 1
                                  : nprobe * nx > 1);

#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
    {
        RangeSearchPartialResult pres(result);
        std::unique_ptr<InvertedListScanner> scanner(
                get_InvertedListScanner(store_pairs));
        // NOTE(review): a failure here still throws from inside the parallel
        // region; consider validating the scanner before entering it.
        FAISS_THROW_IF_NOT(scanner.get());
        all_pres[omp_get_thread_num()] = &pres;

        // Scan list number ik of query i into qres. Every error (including an
        // out-of-range key) is recorded instead of propagated, so that no
        // exception leaves the parallel region.
        auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
            const idx_t key = keys[i * nprobe + ik]; /* select the list */
            if (key < 0) {
                return;
            }
            try {
                FAISS_THROW_IF_NOT_FMT(
                        key < (idx_t)nlist,
                        "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
                        key,
                        ik,
                        nlist);
                const size_t list_size = invlists->list_size(key);
                if (list_size == 0) {
                    return;
                }
                InvertedLists::ScopedCodes scodes(invlists, key);
                InvertedLists::ScopedIds ids(invlists, key);
                scanner->set_list(key, coarse_dis[i * nprobe + ik]);
                nlistv++;
                ndis += list_size;
                scanner->scan_codes_range(
                        list_size, scodes.get(), ids.get(), radius, qres);
            } catch (const std::exception& e) {
                std::lock_guard<std::mutex> lock(exception_mutex);
                exception_string =
                        demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
                interrupt = true;
            }
        };

        if (pmode == 0 || pmode == 3) {
            // Each thread owns whole queries. In mode 3 do_parallel is false,
            // so the region runs on a single thread.
#pragma omp for
            for (idx_t i = 0; i < nx; i++) {
                scanner->set_query(x + i * d);
                RangeQueryResult& qres = pres.new_result(i);
                for (idx_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func(i, ik, qres);
                }
            }
        } else if (pmode == 1) {
            // Every thread visits every query and takes a share of its lists,
            // so one query's results end up spread over several threads.
            for (idx_t i = 0; i < nx; i++) {
                scanner->set_query(x + i * d);
                RangeQueryResult& qres = pres.new_result(i);
#pragma omp for schedule(dynamic)
                for (idx_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func(i, ik, qres);
                }
            }
        } else { // pmode == 2, validated before the parallel region
            RangeQueryResult* qres = nullptr;
#pragma omp for schedule(dynamic)
            for (idx_t iik = 0; iik < nx * nprobe; iik++) {
                const idx_t i = iik / nprobe;
                const idx_t ik = iik % nprobe;
                // open a new result whenever this thread moves to another query
                if (qres == nullptr || qres->qno != i) {
                    FAISS_ASSERT(!qres || i > qres->qno);
                    qres = &pres.new_result(i);
                    scanner->set_query(x + i * d);
                }
                scan_list_func(i, ik, *qres);
            }
        }

        if (pmode == 0 || pmode == 3) {
            // All results of a query were produced by this thread.
            pres.finalize();
        } else {
            // Partial results of one query live in several threads: wait for
            // all of them, then a single thread merges them into result.
#pragma omp barrier
#pragma omp single
            RangeSearchPartialResult::merge(all_pres, false);
#pragma omp barrier
        }
    }

    if (interrupt) {
        if (!exception_string.empty()) {
            FAISS_THROW_FMT(
                    "search interrupted with: %s", exception_string.c_str());
        } else {
            FAISS_THROW_MSG("computation interrupted");
        }
    }
    if (stats) {
        stats->nq += nx;
        stats->nlist += nlistv;
        stats->ndis += ndis;
    }
}