void ResidualQuantizer::train(size_t n, const float* x)

in faiss/impl/ResidualQuantizer.cpp [237:415]

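Trains the M residual codebooks sequentially. At stage m, k-means clusters the current residuals into 2^nbits[m] centroids, then the training set is re-encoded with beam search to produce the codes and residuals for the next stage. After the main loop, the codebooks can optionally be refined, the range of reconstruction norms is recorded (and quantized for the cqint search types), and the codebook cross-product tables are precomputed.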

void ResidualQuantizer::train(size_t n, const float* x) {
    codebooks.resize(d * codebook_offsets.back());

    if (verbose) {
        printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n",
               M,
               n,
               size_t(d));
    }

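    // the beam initially holds one entry per training vector: the vector
    // itself, which is the order-0 residual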
    int cur_beam_size = 1;
    std::vector<float> residuals(x, x + n * d);
    std::vector<int32_t> codes;
    std::vector<float> distances;
    double t0 = getmillisecs();
    double clustering_time = 0;

    for (int m = 0; m < M; m++) {
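        // codebook m has K = 2^nbits[m] centroids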
        int K = 1 << nbits[m];

        // on which residuals to train
        const std::vector<float>* train_residuals = &residuals;
        std::vector<float> residuals1;
        if (train_type & Train_top_beam) {
            // keep only the best (first) beam entry of each vector:
            // faster to cluster, at a small accuracy cost
            residuals1.resize(n * d);
            for (size_t j = 0; j < n; j++) {
                memcpy(residuals1.data() + j * d,
                       residuals.data() + j * d * cur_beam_size,
                       sizeof(residuals[0]) * d);
            }
            train_residuals = &residuals1;
        }
        std::vector<float> codebooks; // stage-m codebook; shadows this->codebooks
        float obj = 0;

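        // index used for the centroid assignment step of k-means;
        // defaults to exact search with a flat L2 index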
        std::unique_ptr<Index> assign_index;
        if (assign_index_factory) {
            assign_index.reset((*assign_index_factory)(d));
        } else {
            assign_index.reset(new IndexFlatL2(d));
        }

        double t1 = getmillisecs();

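        // cluster the training residuals into K centroids; with
        // Train_progressive_dim, k-means runs on a progressively
        // increasing prefix of the dimensions (suited to PCA-rotated data)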
        if (!(train_type & Train_progressive_dim)) { // regular kmeans
            Clustering clus(d, K, cp);
            clus.train(
                    train_residuals->size() / d,
                    train_residuals->data(),
                    *assign_index.get());
            codebooks.swap(clus.centroids);
            assign_index->reset();
            obj = clus.iteration_stats.back().obj;
        } else { // progressive dim clustering
            ProgressiveDimClustering clus(d, K, cp);
            ProgressiveDimIndexFactory default_fac;
            clus.train(
                    train_residuals->size() / d,
                    train_residuals->data(),
                    assign_index_factory ? *assign_index_factory : default_fac);
            codebooks.swap(clus.centroids);
            obj = clus.iteration_stats.back().obj;
        }
        clustering_time += (getmillisecs() - t1) / 1000;

        memcpy(this->codebooks.data() + codebook_offsets[m] * d,
               codebooks.data(),
               codebooks.size() * sizeof(codebooks[0]));

        // quantize using the new codebooks

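        // the beam grows by a factor K, capped at max_beam_size; each beam
        // entry now carries m + 1 codes, a d-dim residual and a distance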
        int new_beam_size = std::min(cur_beam_size * K, max_beam_size);
        std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
        std::vector<float> new_residuals(n * new_beam_size * d);
        std::vector<float> new_distances(n * new_beam_size);

        size_t bs;
        { // determine batch size
            size_t mem = memory_per_point();
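            // memory_per_point() estimates the temporary memory beam search
            // needs per vector; split into batches to stay under the cap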
            if (n > 1 && mem * n > max_mem_distances) {
                // then split queries to reduce temp memory
                bs = std::max(max_mem_distances / mem, size_t(1));
            } else {
                bs = n;
            }
        }

        for (size_t i0 = 0; i0 < n; i0 += bs) {
            size_t i1 = std::min(i0 + bs, n);

            /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
                i0, i1, K, assign_index->ntotal); */

            beam_search_encode_step(
                    d,
                    K,
                    codebooks.data(),
                    i1 - i0,
                    cur_beam_size,
                    residuals.data() + i0 * cur_beam_size * d,
                    m,
                    codes.data() + i0 * cur_beam_size * m,
                    new_beam_size,
                    new_codes.data() + i0 * new_beam_size * (m + 1),
                    new_residuals.data() + i0 * new_beam_size * d,
                    new_distances.data() + i0 * new_beam_size,
                    assign_index.get());
        }
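        // the expanded beam becomes the current beam for the next stage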
        codes.swap(new_codes);
        residuals.swap(new_residuals);
        distances.swap(new_distances);

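        // total beam-search distance across all beam entries, for logging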
        float sum_distances = 0;
        for (size_t j = 0; j < distances.size(); j++) {
            sum_distances += distances[j];
        }

        if (verbose) {
            printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
                   "total distance %g, beam_size %d->%d (batch size %zd)\n",
                   (getmillisecs() - t0) / 1000,
                   clustering_time,
                   m,
                   int(nbits[m]),
                   obj,
                   sum_distances,
                   cur_beam_size,
                   new_beam_size,
                   bs);
        }
        cur_beam_size = new_beam_size;
    }

    is_trained = true;

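    // optionally re-estimate all codebooks jointly, keeping the codes
    // fixed and solving the resulting least-squares problem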
    if (train_type & Train_refine_codebook) {
        for (int iter = 0; iter < niter_codebook_refine; iter++) {
            if (verbose) {
                printf("re-estimating the codebooks to minimize "
                       "quantization errors (iter %d).\n",
                       iter);
            }
            retrain_AQ_codebook(n, x);
        }
    }

    // find min and max norms
    std::vector<float> norms(n);

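    // the final residual is x_i minus its reconstruction, so this squared
    // distance is the squared norm of the decoded approximation of x_i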
    for (size_t i = 0; i < n; i++) {
        norms[i] = fvec_L2sqr(
                x + i * d, residuals.data() + i * cur_beam_size * d, d);
    }

    norm_min = HUGE_VALF;
    norm_max = -HUGE_VALF;
    for (size_t i = 0; i < n; i++) {
        if (norms[i] < norm_min) {
            norm_min = norms[i];
        }
        if (norms[i] > norm_max) {
            norm_max = norms[i];
        }
    }

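    // for the cqint search types, reconstruction norms are quantized with
    // an exact 1D k-means and the centroids are stored in the qnorm index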
    if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
        size_t k = (1 << 8);
        if (search_type == ST_norm_cqint4) {
            k = (1 << 4);
        }
        Clustering1D clus(k);
        clus.train_exact(n, norms.data());
        qnorm.add(clus.k, clus.centroids.data());
    }

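    // precompute the cross-codebook inner-product tables used by the
    // faster LUT-based encoding, unless explicitly skipped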
    if (!(train_type & Skip_codebook_tables)) {
        compute_codebook_tables();
    }
}
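
For context, a minimal sketch of how this training entry point might be exercised. The dimensions, bit widths, and random training data below are illustrative assumptions, not values taken from the listing:

#include <faiss/impl/ResidualQuantizer.h>

#include <random>
#include <vector>

int main() {
    size_t d = 32, M = 4, nbits = 8, n = 10000; // illustrative sizes only
    std::vector<float> xt(n * d);
    std::mt19937 rng(123); // fixed seed, synthetic data for demonstration
    std::normal_distribution<float> distrib;
    for (float& v : xt) {
        v = distrib(rng);
    }

    // 4 codebooks of 8 bits each -> 4-byte codes
    faiss::ResidualQuantizer rq(d, M, nbits);
    rq.verbose = true;
    rq.train(n, xt.data()); // runs the function listed above

    // encode the training set with the trained codebooks
    std::vector<uint8_t> codes(n * rq.code_size);
    rq.compute_codes(xt.data(), codes.data(), n);
    return 0;
}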