in faiss/IndexPQFastScan.cpp [191:310]
void IndexPQFastScan::search_dispatch_implem(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    // Search the n queries in x for their k best codes and write the results
    // to distances / labels (n * k entries each, query-major: query i owns
    // entries [i * k, (i + 1) * k)).
    //
    // The implementation is selected by the `implem` member:
    //   0       automatic: 12 when bbs == 32, otherwise 14; the following
    //           variant (13 / 15) is used when k > 20
    //   1       pq.search over orig_codes (only supported when is_max)
    //   2, 3, 4 generic scan of orig_codes with per-query lookup tables
    //           (distance tables when is_max, inner-product tables
    //           otherwise); 3 and 4 first round the tables with
    //           round_uint8_per_column, and only 4 maps the results back
    //           through the per-query normalizers (3 returns them in the
    //           rounded-table scale)
    //   12..15  search_implem_12 (12, 13) or search_implem_14 (14, 15),
    //           with the queries split evenly over OpenMP threads
    // Any other value throws.
    using Cfloat = typename std::conditional<
            is_max,
            CMax<float, int64_t>,
            CMin<float, int64_t>>::type;

    using C = typename std::
            conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;

    if (n == 0) {
        return;
    }

    // actual implementation used
    int impl = implem;
    if (impl == 0) {
        impl = (bbs == 32) ? 12 : 14;
        if (k > 20) {
            impl++;
        }
    }
    // From here on dispatch on `impl` only; it equals `implem` unless the
    // latter was 0.

    if (impl == 1) {
        FAISS_THROW_IF_NOT(orig_codes);
        FAISS_THROW_IF_NOT(is_max);
        float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
        pq.search(x, n, orig_codes, ntotal, &res, true);
    } else if (impl == 2 || impl == 3 || impl == 4) {
        FAISS_THROW_IF_NOT(orig_codes);

        const size_t dim12 = pq.ksub * pq.M;
        std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
        if (is_max) {
            pq.compute_distance_tables(n, x, dis_tables.get());
        } else {
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
        }

        // Per-query (a, b) pair written by round_uint8_per_column; only
        // populated for impl 3 / 4. impl 2 scans the float tables as-is.
        std::vector<float> normalizers(n * 2);
        if (impl == 3 || impl == 4) {
            for (idx_t i = 0; i < n; i++) {
                round_uint8_per_column(
                        dis_tables.get() + i * dim12,
                        pq.M,
                        pq.ksub,
                        &normalizers[2 * i],
                        &normalizers[2 * i + 1]);
            }
        }

        for (idx_t i = 0; i < n; i++) {
            int64_t* heap_ids = labels + i * k;
            float* heap_dis = distances + i * k;

            heap_heapify<Cfloat>(k, heap_dis, heap_ids);

            pq_estimators_from_tables_generic<Cfloat>(
                    pq,
                    pq.nbits,
                    orig_codes,
                    ntotal,
                    dis_tables.get() + i * dim12,
                    k,
                    heap_dis,
                    heap_ids);

            heap_reorder<Cfloat>(k, heap_dis, heap_ids);

            if (impl == 4) {
                const float a = normalizers[2 * i];
                const float b = normalizers[2 * i + 1];
                for (idx_t j = 0; j < k; j++) {
                    // Slots no code was placed in keep label -1 (assumed
                    // faiss heap convention from heap_heapify); rescaling
                    // their neutral distance would produce an arbitrary or
                    // infinite value, so leave them untouched.
                    if (heap_ids[j] >= 0) {
                        heap_dis[j] = heap_dis[j] / a + b;
                    }
                }
            }
        }
    } else if (impl >= 12 && impl <= 15) {
        FAISS_THROW_IF_NOT(ntotal < INT_MAX);

        // Run the selected fast-scan kernel on a contiguous range of queries.
        auto search_range = [&](idx_t nq,
                                const float* xq,
                                float* dis_q,
                                idx_t* lab_q) {
            if (impl == 12 || impl == 13) {
                search_implem_12<C>(nq, xq, k, dis_q, lab_q, impl);
            } else {
                search_implem_14<C>(nq, xq, k, dis_q, lab_q, impl);
            }
        };

        // Clamp in idx_t before narrowing so a huge n cannot wrap the int.
        const int nt = int(std::min<idx_t>(omp_get_max_threads(), n));
        if (nt < 2) {
            search_range(n, x, distances, labels);
        } else {
            // An exception may not propagate out of an OpenMP parallel
            // region (non-conforming; runtimes typically terminate), so keep
            // the first one thrown by any slice and rethrow it after the
            // region, as the single-threaded branch would surface it.
            std::exception_ptr slice_exception;

            // explicitly slice over threads
#pragma omp parallel for num_threads(nt)
            for (int slice = 0; slice < nt; slice++) {
                const idx_t i0 = n * slice / nt;
                const idx_t i1 = n * (slice + 1) / nt;
                try {
                    search_range(
                            i1 - i0,
                            x + i0 * d,
                            distances + i0 * k,
                            labels + i0 * k);
                } catch (...) {
#pragma omp critical(search_dispatch_implem_exception)
                    {
                        if (!slice_exception) {
                            slice_exception = std::current_exception();
                        }
                    }
                }
            }

            if (slice_exception) {
                std::rethrow_exception(slice_exception);
            }
        }
    } else {
        FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
    }
}