void IndexPQ::search_core_polysemous()

in faiss/IndexPQ.cpp [294:411]


/** Polysemous k-NN search over the PQ codes stored in this index.
 *
 * For every query:
 *   1. a PQ distance table (pq.M * pq.ksub floats) is computed;
 *   2. the query's own PQ code is derived from that table;
 *   3. the scan of the database is delegated to polysemous_inner_loop,
 *      instantiated with a Hamming computer selected from pq.code_size and
 *      search_type (plain Hamming for ST_polysemous, generalized Hamming
 *      otherwise);
 *   4. the k results, kept in a max-heap, are reordered with
 *      maxheap_reorder before returning.
 *
 * @param n         number of queries
 * @param x         query vectors (assumes n * pq.d floats — TODO confirm)
 * @param k         number of neighbors per query; must be > 0
 * @param distances output buffer of n * k distances
 * @param labels    output buffer of n * k ids
 *
 * Throws (via FAISS_THROW_*) if k <= 0, if pq.nbits != 8, or if
 * pq.code_size is not supported by the selected Hamming computer family
 * (a multiple of 4 bytes for ST_polysemous, a multiple of 8 bytes
 * otherwise). All checks happen before any OpenMP region is entered.
 */
void IndexPQ::search_core_polysemous(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const {
    FAISS_THROW_IF_NOT(k > 0);

    FAISS_THROW_IF_NOT(pq.nbits == 8);

    // Validate the code size up front. An exception must not propagate out
    // of an OpenMP parallel region (doing so terminates the program instead
    // of reaching the caller), so the check cannot live in the switches
    // below. Every size those switches dispatch on is a multiple of the
    // modulus used here, so this accepts exactly the same set of sizes.
    const bool use_plain_hamming = (search_type == ST_polysemous);
    const size_t required_multiple = use_plain_hamming ? 4 : 8;
    if (pq.code_size % required_multiple != 0) {
        FAISS_THROW_FMT(
                "code size %zu not supported for polysemous", pq.code_size);
    }

    // PQ distance tables: pq.M * pq.ksub floats per query.
    float* dis_tables = new float[n * pq.ksub * pq.M];
    ScopeDeleter<float> del(dis_tables);
    pq.compute_distance_tables(n, x, dis_tables);

    // Hamming embedding of the queries. The codes are derived from the
    // distance tables computed above rather than by encoding x again with
    // pq.compute_codes.
    uint8_t* q_codes = new uint8_t[n * pq.code_size];
    ScopeDeleter<uint8_t> del2(q_codes);

#pragma omp parallel for
    for (idx_t qi = 0; qi < n; qi++) {
        pq.compute_code_from_distance_table(
                dis_tables + qi * pq.M * pq.ksub,
                q_codes + qi * pq.code_size);
    }

    // Accumulated return values of polysemous_inner_loop (recorded as
    // n_hamming_pass in the stats; presumably the number of codes that
    // passed the Hamming filter).
    size_t n_pass = 0;

#pragma omp parallel for reduction(+ : n_pass)
    for (idx_t qi = 0; qi < n; qi++) {
        const uint8_t* q_code = q_codes + qi * pq.code_size;
        const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;

        int64_t* heap_ids = labels + qi * k;
        float* heap_dis = distances + qi * k;
        maxheap_heapify(k, heap_dis, heap_ids);

        // Nothing in this region may throw: pq.code_size was validated
        // above, so the default branches need no error path.
        if (use_plain_hamming) {
            switch (pq.code_size) {
                case 4:
                    n_pass += polysemous_inner_loop<HammingComputer4>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 8:
                    n_pass += polysemous_inner_loop<HammingComputer8>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 16:
                    n_pass += polysemous_inner_loop<HammingComputer16>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 32:
                    n_pass += polysemous_inner_loop<HammingComputer32>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 20:
                    n_pass += polysemous_inner_loop<HammingComputer20>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                default: // any other multiple of 4 bytes (checked above)
                    n_pass += polysemous_inner_loop<HammingComputerDefault>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
            }
        } else {
            switch (pq.code_size) {
                case 8:
                    n_pass += polysemous_inner_loop<GenHammingComputer8>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 16:
                    n_pass += polysemous_inner_loop<GenHammingComputer16>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                case 32:
                    n_pass += polysemous_inner_loop<GenHammingComputer32>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
                default: // any other multiple of 8 bytes (checked above)
                    n_pass += polysemous_inner_loop<GenHammingComputerM8>(
                            *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                    break;
            }
        }
        maxheap_reorder(k, heap_dis, heap_ids);
    }

    indexPQ_stats.nq += n;
    indexPQ_stats.ncode += n * ntotal;
    indexPQ_stats.n_hamming_pass += n_pass;
}