void IndexPQFastScan::search_dispatch_implem()

in faiss/IndexPQFastScan.cpp [191:310]


void IndexPQFastScan::search_dispatch_implem(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    // Runs a search for n queries stored contiguously in x (query i starts at
    // x + i * d) and writes k results per query into distances / labels
    // (query i owns the slots [i * k, (i + 1) * k)). The work is routed to
    // the implementation selected by `implem`:
    //
    //   0        choose among 12..15 from bbs and k (see below)
    //   1        pq.search() over orig_codes; requires is_max
    //   2        float look-up tables scanned over orig_codes
    //   3, 4     as 2, but the tables first go through
    //            round_uint8_per_column(); 4 additionally maps the reported
    //            distances back using the per-query normalizers
    //   12..15   search_implem_12 / search_implem_14, sliced over OpenMP
    //            threads when more than one thread can be used
    //
    // Any other value raises through FAISS_THROW_FMT. Branches 1..4 also
    // raise when orig_codes is not set.
    //
    // NOTE(review): `is_max` is not declared inside this block; presumably it
    // is a template parameter of this method -- confirm at the declaration.

    // Comparator over float distances, used by the table-scanning branch.
    using Cfloat = typename std::conditional<
            is_max,
            CMax<float, int64_t>,
            CMin<float, int64_t>>::type;

    // Comparator over uint16_t values with int ids, passed on to
    // search_implem_12 / search_implem_14.
    using C = typename std::
            conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;

    if (n == 0) {
        return;
    }

    // actual implementation used
    int impl = implem;

    if (impl == 0) {
        // Automatic choice: 12 for bbs == 32, 14 otherwise; the odd-numbered
        // variant (13 / 15) is taken when more than 20 results are requested.
        impl = (bbs == 32) ? 12 : 14;
        if (k > 20) {
            impl++;
        }
    }

    // `impl` differs from `implem` only when implem == 0, in which case it is
    // in 12..15, so testing `impl` everywhere below is equivalent to testing
    // `implem` for the 1..4 branches.
    if (impl == 1) {
        FAISS_THROW_IF_NOT(orig_codes);
        FAISS_THROW_IF_NOT(is_max);
        float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
        pq.search(x, n, orig_codes, ntotal, &res, true);
    } else if (impl == 2 || impl == 3 || impl == 4) {
        FAISS_THROW_IF_NOT(orig_codes);

        // One table of pq.M * pq.ksub floats per query.
        const size_t dim12 = pq.ksub * pq.M;
        std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
        if (is_max) {
            pq.compute_distance_tables(n, x, dis_tables.get());
        } else {
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
        }

        // Two normalizer values per query, filled by
        // round_uint8_per_column(). impl == 2 keeps the float tables
        // untouched, so nothing is allocated for it.
        std::vector<float> normalizers;
        if (impl == 3 || impl == 4) {
            normalizers.resize(2 * n);
            for (idx_t i = 0; i < n; i++) {
                round_uint8_per_column(
                        dis_tables.get() + i * dim12,
                        pq.M,
                        pq.ksub,
                        &normalizers[2 * i],
                        &normalizers[2 * i + 1]);
            }
        }

        for (idx_t i = 0; i < n; i++) {
            int64_t* heap_ids = labels + i * k;
            float* heap_dis = distances + i * k;

            heap_heapify<Cfloat>(k, heap_dis, heap_ids);

            pq_estimators_from_tables_generic<Cfloat>(
                    pq,
                    pq.nbits,
                    orig_codes,
                    ntotal,
                    dis_tables.get() + i * dim12,
                    k,
                    heap_dis,
                    heap_ids);

            heap_reorder<Cfloat>(k, heap_dis, heap_ids);

            if (impl == 4) {
                // Map the distances computed on the rounded tables back with
                // this query's two normalizers.
                const float a = normalizers[2 * i];
                const float b = normalizers[2 * i + 1];

                // idx_t counter: k is 64-bit, an int counter could overflow.
                for (idx_t j = 0; j < k; j++) {
                    heap_dis[j] = heap_dis[j] / a + b;
                }
            }
        }
    } else if (impl >= 12 && impl <= 15) {
        FAISS_THROW_IF_NOT(ntotal < INT_MAX);
        // Never use more threads than there are queries.
        const int nt = std::min(omp_get_max_threads(), int(n));
        if (nt < 2) {
            if (impl == 12 || impl == 13) {
                search_implem_12<C>(n, x, k, distances, labels, impl);
            } else {
                search_implem_14<C>(n, x, k, distances, labels, impl);
            }
        } else {
            // explicitly slice over threads: slice s handles the queries
            // [n * s / nt, n * (s + 1) / nt) and writes only to the matching
            // rows of distances / labels.
#pragma omp parallel for num_threads(nt)
            for (int slice = 0; slice < nt; slice++) {
                idx_t i0 = n * slice / nt;
                idx_t i1 = n * (slice + 1) / nt;
                float* dis_i = distances + i0 * k;
                idx_t* lab_i = labels + i0 * k;
                if (impl == 12 || impl == 13) {
                    search_implem_12<C>(
                            i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
                } else {
                    search_implem_14<C>(
                            i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
                }
            }
        }
    } else {
        FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
    }
}