void IndexIVFPQFastScan::search_dispatch_implem()

in faiss/IndexIVFPQFastScan.cpp [532:653]


/**
 * Pick the concrete search implementation and run it on a batch of queries.
 *
 * When `implem` is 0 the implementation is selected automatically: 12 if
 * bbs == 32 and 10 otherwise, plus one when k > 20. Implementations 1 and 2
 * are forwarded as-is. Implementations 10 to 13 go to search_implem_10
 * (10, 11) or search_implem_12 (12, 13); for n >= 2 the batch is cut into
 * slices that are processed by an OpenMP parallel loop. Any other value is
 * reported through FAISS_THROW_FMT.
 *
 * For implementations 10 to 13, n and the ndis / nlist_visited counters
 * filled in by the callees are added to indexIVF_stats.
 *
 * NOTE(review): `is_max` is not declared in this excerpt; it only selects
 * between the CMax and CMin comparator types below.
 *
 * @param n          number of queries; the call is a no-op when 0
 * @param x          queries, query i starts at x + i * d
 * @param k          number of results per query
 * @param distances  output, results of query i start at distances + i * k
 * @param labels     output, same layout as distances
 */
void IndexIVFPQFastScan::search_dispatch_implem(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    using Cfloat = typename std::conditional<
            is_max,
            CMax<float, int64_t>,
            CMin<float, int64_t>>::type;

    using C = typename std::conditional<
            is_max,
            CMax<uint16_t, int64_t>,
            CMin<uint16_t, int64_t>>::type;

    if (n == 0) {
        return;
    }

    // actual implementation used
    int impl = implem;

    if (impl == 0) {
        impl = (bbs == 32) ? 12 : 10;
        if (k > 20) {
            impl++;
        }
    }

    if (impl == 1) {
        search_implem_1<Cfloat>(n, x, k, distances, labels);
    } else if (impl == 2) {
        search_implem_2<C>(n, x, k, distances, labels);
    } else if (impl >= 10 && impl <= 13) {
        size_t ndis = 0, nlist_visited = 0;

        // Single place that forwards a (sub-)batch to the right
        // implementation; shared by the direct and the sliced code paths.
        const bool use_implem_12 = (impl == 12 || impl == 13);
        auto run_batch = [this, impl, k, use_implem_12](
                                 idx_t nq,
                                 const float* xq,
                                 float* dis,
                                 idx_t* lab,
                                 size_t* ndis_out,
                                 size_t* nlist_out) {
            if (use_implem_12) {
                search_implem_12<C>(
                        nq, xq, k, dis, lab, impl, ndis_out, nlist_out);
            } else {
                search_implem_10<C>(
                        nq, xq, k, dis, lab, impl, ndis_out, nlist_out);
            }
        };

        if (n < 2) {
            // a single query: call directly, no slicing
            run_batch(n, x, distances, labels, &ndis, &nlist_visited);
        } else {
            // explicitly slice over threads
            const int max_threads = omp_get_max_threads();
            int nslice;
            if (n <= max_threads) {
                // one query per slice
                nslice = static_cast<int>(n);
            } else if (by_residual && metric_type == METRIC_L2) {
                // make sure we don't make too big LUT tables
                size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
                        (sizeof(float) + sizeof(uint8_t));
                // The product is 0 when e.g. nprobe == 0; clamp it so the
                // division below is always defined.
                lut_size_per_query =
                        std::max(lut_size_per_query, size_t(1));

                size_t max_lut_size = precomputed_table_max_bytes;
                // how many queries we can handle within mem budget
                size_t nq_ok =
                        std::max(max_lut_size / lut_size_per_query, size_t(1));
                nslice =
                        roundup(std::max(size_t(n / nq_ok), size_t(1)),
                                max_threads);
            } else {
                // LUTs unlikely to be a limiting factor
                nslice = max_threads;
            }

            // Inside the loop ndis / nlist_visited name the per-thread
            // reduction copies, so the pointers handed to the callee are
            // private to the calling thread.
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
            for (int slice = 0; slice < nslice; slice++) {
                // slice covers queries [i0, i1)
                idx_t i0 = n * slice / nslice;
                idx_t i1 = n * (slice + 1) / nslice;
                run_batch(
                        i1 - i0,
                        x + i0 * d,
                        distances + i0 * k,
                        labels + i0 * k,
                        &ndis,
                        &nlist_visited);
            }
        }
        indexIVF_stats.nq += n;
        indexIVF_stats.ndis += ndis;
        indexIVF_stats.nlist += nlist_visited;
    } else {
        FAISS_THROW_FMT("implem %d does not exist", implem);
    }
}