void ITQMatrix::train()

in faiss/VectorTransform.cpp [725:863]


/** Train the rotation stored in A from n training vectors.
 *
 * The rotation starts from init_rotation when that buffer holds exactly
 * d * d values (d = d_in); otherwise it is taken from a
 * RandomRotationMatrix initialized with `seed`. It is then refined for
 * max_iter rounds, each of which:
 *   1. rotates the training data,
 *   2. binarizes the rotated data to -1 / +1,
 *   3. multiplies the binarized data with the original data (d x d result),
 *   4. computes the SVD of that d x d matrix,
 *   5. rebuilds the rotation from the two singular-vector matrices.
 * Finally the transpose of the rotation is copied into A and is_trained
 * is set to true.
 *
 * All BLAS/LAPACK calls use the Fortran column-major convention, so a
 * row-major (n x d) buffer is seen by them as a (d x n) matrix.
 *
 * @param n   number of training vectors
 * @param xf  training data, read as n * d contiguous floats
 *
 * Throws through FAISS_THROW_IF_NOT_FMT when dgesvd reports info != 0.
 */
void ITQMatrix::train(Index::idx_t n, const float* xf) {
    size_t d = d_in;
    std::vector<double> rotation(d * d);

    if (init_rotation.size() == d * d) {
        // caller-supplied starting rotation
        // NOTE(review): the memcpy assumes init_rotation stores doubles
        // -- confirm against the class declaration
        memcpy(rotation.data(),
               init_rotation.data(),
               d * d * sizeof(rotation[0]));
    } else {
        // no usable initial rotation: start from a seeded random one
        RandomRotationMatrix rrot(d, d);
        rrot.init(seed);
        for (size_t i = 0; i < d * d; i++) {
            rotation[i] = rrot.A[i];
        }
    }

    // double-precision copy of the training set for dgemm / dgesvd
    std::vector<double> x(n * d);

    for (size_t i = 0; i < n * d; i++) {
        x[i] = xf[i];
    }

    // work buffers, allocated once and reused by every iteration
    std::vector<double> rotated_x(n * d), cov_mat(d * d);
    std::vector<double> u(d * d), vt(d * d), singvals(d);

    for (int i = 0; i < max_iter; i++) {
        print_if_verbose("rotation", rotation, d, d);
        { // rotated_data = np.dot(training_data, rotation)
            // column-major view: rotated_x (d x n) = rotation (d x d) * x (d x n)
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_("N",
                   "N",
                   &di,
                   &ni,
                   &di,
                   &one,
                   rotation.data(),
                   &di,
                   x.data(),
                   &di,
                   &zero,
                   rotated_x.data(),
                   &di);
        }
        print_if_verbose("rotated_x", rotated_x, n, d);
        // binarize: negative entries become -1, everything else +1
        for (size_t j = 0; j < n * d; j++) {
            rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
        }
        // covariance matrix
        { // column-major view: cov_mat (d x d) = rotated_x (d x n) * x^T (n x d),
          // i.e. the product of the binarized data with the original data
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_("N",
                   "T",
                   &di,
                   &di,
                   &ni,
                   &one,
                   rotated_x.data(),
                   &di,
                   x.data(),
                   &di,
                   &zero,
                   cov_mat.data(),
                   &di);
        }
        print_if_verbose("cov_mat", cov_mat, d, d);
        // SVD
        {
            FINTEGER di = d;
            FINTEGER lwork = -1, info;
            double lwork1;

            // workspace query: with lwork == -1 dgesvd only reports the
            // optimal work-array size in lwork1
            dgesvd_("A",
                    "A",
                    &di,
                    &di,
                    cov_mat.data(),
                    &di,
                    singvals.data(),
                    u.data(),
                    &di,
                    vt.data(),
                    &di,
                    &lwork1,
                    &lwork,
                    &info);

            // info is a FINTEGER whose width is not visible here; cast it
            // so the argument always matches the %d conversion
            FAISS_THROW_IF_NOT_FMT(
                    info == 0,
                    "dgesvd workspace query returned info=%d",
                    int(info));
            lwork = static_cast<FINTEGER>(lwork1);
            std::vector<double> work(lwork);
            // actual decomposition; cov_mat is passed as the in/out matrix
            dgesvd_("A",
                    "A",
                    &di,
                    &di,
                    cov_mat.data(),
                    &di,
                    singvals.data(),
                    u.data(),
                    &di,
                    vt.data(),
                    &di,
                    work.data(),
                    &lwork,
                    &info);
            FAISS_THROW_IF_NOT_FMT(
                    info == 0, "dgesvd returned info=%d", int(info));
        }
        print_if_verbose("u", u, d, d);
        print_if_verbose("vt", vt, d, d);
        // update rotation
        { // column-major view: rotation (d x d) = u (d x d) * vt^T (d x d)
            FINTEGER di = d;
            double one = 1, zero = 0;
            dgemm_("N",
                   "T",
                   &di,
                   &di,
                   &di,
                   &one,
                   u.data(),
                   &di,
                   vt.data(),
                   &di,
                   &zero,
                   rotation.data(),
                   &di);
        }
        print_if_verbose("final rot", rotation, d, d);
    }
    // store the transpose of the trained rotation in A
    A.resize(d * d);
    for (size_t i = 0; i < d; i++) {
        for (size_t j = 0; j < d; j++) {
            A[i + d * j] = rotation[j + d * i];
        }
    }
    is_trained = true;
}