void ParameterSpace::explore()

in faiss/AutoTune.cpp [584:721]


/**
 * Run searches on `index` for a set of parameter combinations and record
 * the resulting (performance, search time) pairs in `ops`.
 *
 * Two modes, selected by the `n_experiments` member:
 *  - `n_experiments == 0`: every combination 0..n_combinations()-1 is
 *    evaluated exactly once with a single `index->search` call.
 *  - otherwise: at most `n_experiments` combinations are tried, in a
 *    fixed pseudo-random order (seed 1234) that always starts with the
 *    first and the last combination. A combination is skipped when the
 *    bounds derived from already recorded points show it cannot improve
 *    on them. Searches are issued in batches of `batchsize` queries and
 *    repeated until `min_test_duration` seconds have elapsed; the
 *    reported time is the average over the runs.
 *
 * @param index  index whose parameters are set and which is searched
 * @param nq     number of query vectors; must equal `crit.nq`
 * @param xq     query vectors, read as `nq` rows of `index->d` floats
 * @param crit   criterion giving the number of neighbors (`crit.nnn`)
 *               and evaluating the search results
 * @param ops    receives one operating point per evaluated combination
 *
 * Throws (via the FAISS_THROW_* macros) if `nq != crit.nq`, if the
 * sampled mode is requested with more than one combination but fewer
 * than three experiments, or if `batchsize` is 0 while there are queries
 * to process.
 */
void ParameterSpace::explore(
        Index* index,
        size_t nq,
        const float* xq,
        const AutoTuneCriterion& crit,
        OperatingPoints* ops) const {
    FAISS_THROW_IF_NOT_MSG(
            nq == crit.nq, "criterion does not have the same nb of queries");

    const size_t n_comb = n_combinations();

    if (n_experiments == 0) {
        // Exhaustive mode: one un-batched search per combination.
        for (size_t cno = 0; cno < n_comb; cno++) {
            set_index_parameters(index, cno);
            std::vector<Index::idx_t> I(nq * crit.nnn);
            std::vector<float> D(nq * crit.nnn);

            const double t0 = getmillisecs();
            index->search(nq, xq, crit.nnn, D.data(), I.data());
            const double t_search = (getmillisecs() - t0) / 1e3;

            const double perf = crit.evaluate(D.data(), I.data());

            // Build the name once; it is needed for both add() and the log.
            const auto name = combination_name(cno);
            const bool keep = ops->add(perf, t_search, name, cno);

            if (verbose) {
                printf("  %zu/%zu: %s perf=%.3f t=%.3f s %s\n",
                       cno,
                       n_comb,
                       name.c_str(),
                       perf,
                       t_search,
                       keep ? "*" : "");
            }
        }
        return;
    }

    // The batch loops below advance by `batchsize`; a zero step would
    // never terminate as soon as there is at least one query.
    FAISS_THROW_IF_NOT_MSG(
            nq == 0 || batchsize > 0,
            "batchsize must be > 0 to explore with a bounded number of experiments");

    // Clamp the number of experiments to the number of combinations.
    // A negative request is treated as "all combinations", which is what
    // the former implicit signed -> unsigned comparison amounted to.
    const int requested = n_experiments;
    const size_t n_exp =
            (requested < 0 || static_cast<size_t>(requested) > n_comb)
            ? n_comb
            : static_cast<size_t>(requested);

    FAISS_THROW_IF_NOT(n_comb == 1 || n_exp > 2);

    std::vector<int> perm(n_comb);
    // make sure the slowest and fastest experiment are run
    perm[0] = 0;
    if (n_comb > 1) {
        // n_comb >= 3 here: the check above rejects n_comb == 2.
        perm[1] = n_comb - 1;
        // Permutation of 0..n_comb-3, shifted by one so that it covers
        // 1..n_comb-2 (the combinations not already placed).
        rand_perm(perm.data() + 2, n_comb - 2, 1234);
        for (size_t i = 2; i < perm.size(); i++) {
            perm[i]++;
        }
    }

    // Search one batch of queries starting at query q0, writing the
    // results at the matching offset of the output buffers.
    auto search_batch = [&](size_t q0,
                            std::vector<float>& D,
                            std::vector<Index::idx_t>& I) {
        size_t q1 = q0 + batchsize;
        if (q1 > nq) {
            q1 = nq;
        }
        index->search(
                q1 - q0,
                xq + q0 * index->d,
                crit.nnn,
                D.data() + q0 * crit.nnn,
                I.data() + q0 * crit.nnn);
    };

    for (size_t xp = 0; xp < n_exp; xp++) {
        const size_t cno = perm[xp];

        if (verbose) {
            printf("  %zu/%zu: cno=%zu %s ",
                   xp,
                   n_exp,
                   cno,
                   combination_name(cno).c_str());
        }

        {
            // Derive bounds for this combination from every point
            // recorded so far, then skip it if the time already
            // achievable at that performance is not above the lower
            // bound on this combination's time.
            double lower_bound_t = 0.0;
            double upper_bound_perf = 1.0;
            for (size_t i = 0; i < ops->all_pts.size(); i++) {
                update_bounds(
                        cno,
                        ops->all_pts[i],
                        &upper_bound_perf,
                        &lower_bound_t);
            }
            const double best_t = ops->t_for_perf(upper_bound_perf);
            const bool skip = best_t <= lower_bound_t;
            if (verbose) {
                printf("bounds [perf<=%.3f t>=%.3f] %s",
                       upper_bound_perf,
                       lower_bound_t,
                       skip ? "skip\n" : "");
            }
            if (skip) {
                continue;
            }
        }

        set_index_parameters(index, cno);
        std::vector<Index::idx_t> I(nq * crit.nnn);
        std::vector<float> D(nq * crit.nnn);

        const double t0 = getmillisecs();

        int nrun = 0;
        double t_search = 0.0;

        // Repeat the whole query set until the measurement is long
        // enough; always runs at least once.
        do {
            if (thread_over_batches) {
                // Signed loop variable and bound, kept for the OpenMP
                // loop; avoids a signed/unsigned comparison.
                const Index::idx_t nq_signed = static_cast<Index::idx_t>(nq);
                const Index::idx_t step =
                        static_cast<Index::idx_t>(batchsize);
#pragma omp parallel for
                for (Index::idx_t q0 = 0; q0 < nq_signed; q0 += step) {
                    search_batch(static_cast<size_t>(q0), D, I);
                }
            } else {
                for (size_t q0 = 0; q0 < nq; q0 += batchsize) {
                    search_batch(q0, D, I);
                }
            }
            nrun++;
            t_search = (getmillisecs() - t0) / 1e3;

        } while (t_search < min_test_duration);

        // Average time per run.
        t_search /= nrun;

        const double perf = crit.evaluate(D.data(), I.data());

        const bool keep = ops->add(perf, t_search, combination_name(cno), cno);

        if (verbose) {
            printf(" perf %.3f t %.3f (%d %s) %s\n",
                   perf,
                   t_search,
                   nrun,
                   nrun >= 2 ? "runs" : "run",
                   keep ? "*" : "");
        }
    }
}