void OPQMatrix::train()

in faiss/VectorTransform.cpp [987:1200]


/**
 * Train the OPQ rotation stored in A by alternating between training a
 * product quantizer on the rotated data and re-estimating the rotation
 * from an SVD.
 *
 * Steps, in the order performed below:
 *  1. the input goes through fvecs_maybe_subsample (bounded by
 *     max_train_points), which may replace x and n;
 *  2. the vectors are centered and zero-padded from d_in to
 *     d = max(d_in, d_out) components;
 *  3. if A is empty, it is filled with a random matrix (seed 1234) passed
 *     through matrix_qr and truncated to d * d_out entries; otherwise A
 *     must already hold d * d_out entries and is used as the starting
 *     point;
 *  4. for niter iterations: project the data, train the PQ, encode and
 *     decode, then rebuild the rotation from the SVD of the product of the
 *     reconstructions with the training vectors;
 *  5. when d_out > d_in, the padding columns are stripped from A so that it
 *     ends with d_in * d_out entries.
 *
 * Side effects: overwrites A, sets is_trained and is_orthonormal, and
 * modifies the quantizer in use (cp.max_points_per_centroid, cp.niter,
 * verbose, train_type). That quantizer is *pq when pq is set, so the caller's
 * object is modified in that case.
 *
 * @param n  number of training vectors
 * @param x  training vectors, read as n consecutive rows of d_in floats
 *
 * Raises through FAISS_THROW_IF_NOT when a pre-set A has an unexpected size
 * or when sgesvd returns a non-zero info.
 */
void OPQMatrix::train(Index::idx_t n, const float* x) {
    const float* x_in = x;

    x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);

    // x differs from x_in only if a different buffer was returned; in that
    // case it is handed to del_x (presumably released at scope exit, going
    // by the ScopeDeleter name -- confirm against its definition).
    ScopeDeleter<float> del_x(x != x_in ? x : nullptr);

    // To support d_out > d_in, we pad input vectors with 0s to d_out
    size_t d = d_out <= d_in ? d_in : d_out;
    size_t d2 = d_out;

#if 0
    // what this test shows: the only way of getting bit-exact
    // reproducible results with sgeqrf and sgesvd seems to be forcing
    // single-threading.
    { // test repro
        std::vector<float> r (d * d);
        float * rotation = r.data();
        float_randn (rotation, d * d, 1234);
        printf("CS0: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        matrix_qr (d, d, rotation);
        printf("CS1: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        return;
    }
#endif

    if (verbose) {
        printf("OPQMatrix::train: training an OPQ rotation matrix "
               "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
               M,
               n,
               d_in,
               d_out);
    }

    // n rows of d floats; value-initialised, so the padding stays at 0
    std::vector<float> xtrain(n * d);
    // center x
    {
        // per-dimension sum, then mean; entries beyond d_in remain 0
        std::vector<float> sum(d);
        const float* xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                sum[j] += *xi++;
        }
        for (size_t i = 0; i < d; i++)
            sum[i] /= n;
        float* yi = xtrain.data();
        xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                *yi++ = *xi++ - sum[j];
            // skip the zero padding at the end of each row
            yi += d - d_in;
        }
    }
    float* rotation;

    if (A.size() == 0) {
        A.resize(d * d);
        rotation = A.data();
        if (verbose)
            printf("  OPQMatrix::train: making random %zu*%zu rotation\n",
                   d,
                   d);
        float_randn(rotation, d * d, 1234);
        matrix_qr(d, d, rotation);
        // we use only the d * d2 upper part of the matrix
        A.resize(d * d2);
    } else {
        FAISS_THROW_IF_NOT(A.size() == d * d2);
        rotation = A.data();
    }

    // xproj / pq_recons: n rows of d2 floats.
    // xxr: receives the d2 x d product computed by the second sgemm below
    //      and is then decomposed by sgesvd, so it needs d * d2 entries
    //      regardless of n.
    // tmp: scratch for the SVD outputs (two d * d blocks + singular values).
    std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * d2),
            tmp(d * d * 4);

    // Fall back to a local quantizer with 8 bits per sub-quantizer when no
    // external one is provided.
    ProductQuantizer pq_default(d2, M, 8);
    ProductQuantizer& pq_regular = pq ? *pq : pq_default;
    std::vector<uint8_t> codes(pq_regular.code_size * n);

    double t0 = getmillisecs();
    for (int iter = 0; iter < niter; iter++) {
        { // torch.mm(xtrain, rotation:t())
            FINTEGER di = d, d2i = d2, ni = n;
            float zero = 0, one = 1;
            sgemm_("Transposed",
                   "Not transposed",
                   &d2i,
                   &ni,
                   &di,
                   &one,
                   rotation,
                   &di,
                   xtrain.data(),
                   &di,
                   &zero,
                   xproj.data(),
                   &d2i);
        }

        pq_regular.cp.max_points_per_centroid = 1000;
        // the first outer iteration uses its own PQ iteration count
        pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
        pq_regular.verbose = verbose;
        pq_regular.train(n, xproj.data());

        if (verbose) {
            printf("    encode / decode\n");
        }
        if (pq_regular.assign_index) {
            pq_regular.compute_codes_with_assign_index(
                    xproj.data(), codes.data(), n);
        } else {
            pq_regular.compute_codes(xproj.data(), codes.data(), n);
        }
        pq_regular.decode(codes.data(), pq_recons.data(), n);

        // reconstruction error per training vector, reported in verbose mode
        float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;

        if (verbose)
            printf("    Iteration %d (%d PQ iterations):"
                   "%.3f s, obj=%g\n",
                   iter,
                   pq_regular.cp.niter,
                   (getmillisecs() - t0) / 1000.0,
                   pq_err);

        {
            float *u = tmp.data(), *vt = &tmp[d * d];
            float* sing_val = &tmp[2 * d * d];
            FINTEGER di = d, d2i = d2, ni = n;
            float one = 1, zero = 0;

            if (verbose) {
                printf("    X * recons\n");
            }
            // torch.mm(xtrain:t(), pq_recons)
            sgemm_("Not",
                   "Transposed",
                   &d2i,
                   &di,
                   &ni,
                   &one,
                   pq_recons.data(),
                   &d2i,
                   xtrain.data(),
                   &di,
                   &zero,
                   xxr.data(),
                   &d2i);

            FINTEGER lwork = -1, info = -1;
            float worksz = 0;
            // workspace query (lwork == -1): only the optimal work size is
            // returned in worksz
            sgesvd_("All",
                    "All",
                    &d2i,
                    &di,
                    xxr.data(),
                    &d2i,
                    sing_val,
                    vt,
                    &d2i,
                    u,
                    &di,
                    &worksz,
                    &lwork,
                    &info);
            // a failed query would leave worksz meaningless
            FAISS_THROW_IF_NOT(info == 0);

            lwork = int(worksz);
            std::vector<float> work(lwork);
            // u and vt swapped: the buffer named vt sits in LAPACK's U slot
            // (leading dimension d2) and u sits in the VT slot (leading
            // dimension d)
            sgesvd_("All",
                    "All",
                    &d2i,
                    &di,
                    xxr.data(),
                    &d2i,
                    sing_val,
                    vt,
                    &d2i,
                    u,
                    &di,
                    work.data(),
                    &lwork,
                    &info);
            // do not build a rotation from an SVD that did not succeed
            FAISS_THROW_IF_NOT(info == 0);

            // new rotation (d x d2, leading dimension d) from the two
            // orthogonal factors
            sgemm_("Transposed",
                   "Transposed",
                   &di,
                   &d2i,
                   &d2i,
                   &one,
                   u,
                   &di,
                   vt,
                   &d2i,
                   &zero,
                   rotation,
                   &di);
        }
        // subsequent PQ trainings start from the current state
        pq_regular.train_type = ProductQuantizer::Train_hot_start;
    }

    // revert A matrix: drop the columns that correspond to the zero padding,
    // compacting each of the d_out rows from stride d to stride d_in
    if (d > d_in) {
        for (long i = 0; i < d_out; i++)
            memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
        A.resize(d_in * d_out);
    }

    is_trained = true;
    is_orthonormal = true;
}