void IndexHNSW2Level::search()

in faiss/IndexHNSW.cpp [1033:1172]


/**
 * Find the k nearest neighbors of n query vectors.
 *
 * Two code paths exist, selected by the dynamic type of `storage`:
 *  - Index2Layer storage: the call is forwarded unchanged to
 *    IndexHNSW::search.
 *  - IndexIVFPQ storage ("mixed" search): an IVFPQ search
 *    (coarse quantizer search + search_preassigned) first fills
 *    `distances` / `labels`; then, for each query, every id stored in the
 *    probed inverted lists is marked as visited and the IVFPQ results are
 *    used as entry candidates for a graph search that refines the result
 *    in place.
 *
 * Any other storage type is rejected with a Faiss exception.
 *
 * @param n          number of query vectors
 * @param x          query vectors; query i is read starting at x + i * d
 * @param k          number of results per query, must be > 0
 * @param distances  output distances; query i writes to distances + i * k
 * @param labels     output ids; query i writes to labels + i * k
 */
void IndexHNSW2Level::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT(k > 0);

    if (dynamic_cast<const Index2Layer*>(storage)) {
        IndexHNSW::search(n, x, k, distances, labels);

    } else { // "mixed" search
        size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;

        const IndexIVFPQ* index_ivfpq =
                dynamic_cast<const IndexIVFPQ*>(storage);
        // The cast yields nullptr for any other storage type; fail with an
        // exception instead of dereferencing a null pointer below.
        FAISS_THROW_IF_NOT(index_ivfpq != nullptr);

        int nprobe = index_ivfpq->nprobe;

        // Per-query coarse assignment: nprobe list ids and distances each.
        std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        index_ivfpq->quantizer->search(
                n, x, nprobe, coarse_dis.get(), coarse_assign.get());

        // Initial IVFPQ result, written straight into the output arrays.
        index_ivfpq->search_preassigned(
                n,
                x,
                k,
                coarse_assign.get(),
                coarse_dis.get(),
                distances,
                labels,
                false);

#pragma omp parallel
        {
            // Per-thread scratch state, reused across the queries handled
            // by this thread.
            VisitedTable vt(ntotal);
            std::unique_ptr<DistanceComputer> dis(
                    storage_distance_computer(storage));

            int candidates_size = hnsw.upper_beam;
            MinimaxHeap candidates(candidates_size);

#pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
            for (idx_t i = 0; i < n; i++) {
                idx_t* idxi = labels + i * k;
                float* simi = distances + i * k;
                dis->set_query(x + i * d);

                // mark all inverted list elements as visited

                for (int j = 0; j < nprobe; j++) {
                    idx_t key = coarse_assign[j + i * nprobe];
                    // a negative key ends the probe list of this query
                    if (key < 0)
                        break;
                    size_t list_length = index_ivfpq->get_list_size(key);
                    const idx_t* ids = index_ivfpq->invlists->get_ids(key);

                    for (size_t jj = 0; jj < list_length; jj++) {
                        vt.set(ids[jj]);
                    }

                    // get_ids must be balanced by release_ids: some
                    // inverted-list implementations allocate or lock there.
                    index_ivfpq->invlists->release_ids(key, ids);
                }

                candidates.clear();
                // copy the upper_beam elements to candidates list

                // Hard-coded selection of the refinement strategy; only
                // the policy 2 branch is reachable with this value.
                const int search_policy = 2;

                if (search_policy == 1) {
                    for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
                        if (idxi[j] < 0)
                            break;
                        candidates.push(idxi[j], simi[j]);
                        // search_from_candidates adds them back
                        idxi[j] = -1;
                        simi[j] = HUGE_VAL;
                    }

                    // reorder from sorted to heap
                    maxheap_heapify(k, simi, idxi, simi, idxi, k);

                    HNSWStats search_stats;
                    hnsw.search_from_candidates(
                            *dis,
                            k,
                            idxi,
                            simi,
                            candidates,
                            vt,
                            search_stats,
                            0,
                            k);
                    n1 += search_stats.n1;
                    n2 += search_stats.n2;
                    n3 += search_stats.n3;
                    ndis += search_stats.ndis;
                    nreorder += search_stats.nreorder;

                    vt.advance();

                } else if (search_policy == 2) {
                    // Unlike policy 1, the seeds are left in the result
                    // arrays while also being pushed as candidates.
                    for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
                        if (idxi[j] < 0)
                            break;
                        candidates.push(idxi[j], simi[j]);
                    }

                    // reorder from sorted to heap
                    maxheap_heapify(k, simi, idxi, simi, idxi, k);

                    HNSWStats search_stats;
                    search_from_candidates_2(
                            hnsw,
                            *dis,
                            k,
                            idxi,
                            simi,
                            candidates,
                            vt,
                            search_stats,
                            0,
                            k);
                    n1 += search_stats.n1;
                    n2 += search_stats.n2;
                    n3 += search_stats.n3;
                    ndis += search_stats.ndis;
                    nreorder += search_stats.nreorder;

                    // NOTE(review): advanced twice here versus once for
                    // policy 1; presumably search_from_candidates_2 uses
                    // two distinct visit marks per query -- confirm
                    // against its definition.
                    vt.advance();
                    vt.advance();
                }

                // turn the result heap back into a sorted list
                maxheap_reorder(k, simi, idxi);
            }
        }

        hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
    }
}