in faiss/VectorTransform.cpp [725:863]
/*
 * Learn a d x d rotation from n training vectors and store it in A.
 *
 * xf points to n * d_in floats, one vector after the other (a row-major
 * n x d matrix). The Fortran BLAS/LAPACK calls below see that buffer as a
 * column-major d x n matrix with leading dimension d. Every d x d buffer
 * used here (rotation, cov_mat, u, vt) is likewise column-major.
 *
 * The rotation starts from init_rotation when it holds exactly d * d
 * values, otherwise from a RandomRotationMatrix initialized with `seed`.
 * Each of the max_iter iterations then:
 *   1. rotates the training vectors,
 *   2. replaces every rotated coordinate by its sign (-1 or +1),
 *   3. forms the d x d product of those sign codes with the original data,
 *   4. takes the SVD of that product and rebuilds the rotation from its
 *      singular vectors.
 * All intermediates are computed in double precision. Throws (through the
 * FAISS_THROW_* macros) if dgesvd_ reports a non-zero info code.
 */
void ITQMatrix::train(Index::idx_t n, const float* xf) {
    const size_t d = d_in;

    // Current rotation, column-major.
    std::vector<double> rotation(d * d);
    if (init_rotation.size() == d * d) {
        // Element-wise copy (converting if necessary) instead of memcpy, so
        // the copy stays correct whatever the element type of init_rotation.
        rotation.assign(init_rotation.begin(), init_rotation.end());
    } else {
        RandomRotationMatrix rrot(d, d);
        rrot.init(seed);
        for (size_t i = 0; i < d * d; i++) {
            rotation[i] = rrot.A[i];
        }
    }

    // Double-precision copy of the training data.
    std::vector<double> x(xf, xf + n * d);

    std::vector<double> rotated_x(n * d), cov_mat(d * d);
    std::vector<double> u(d * d), vt(d * d), singvals(d);

    for (int i = 0; i < max_iter; i++) {
        print_if_verbose("rotation", rotation, d, d);

        { // rotated_data = np.dot(training_data, rotation)
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_("N",
                   "N",
                   &di,
                   &ni,
                   &di,
                   &one,
                   rotation.data(),
                   &di,
                   x.data(),
                   &di,
                   &zero,
                   rotated_x.data(),
                   &di);
        }
        print_if_verbose("rotated_x", rotated_x, n, d);

        // Binarize in place: negative -> -1, everything else -> +1.
        for (size_t j = 0; j < n * d; j++) {
            rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
        }

        // Covariance step. With B the binarized codes and X the training
        // data (both n x d, row-major), this computes the d x d column-major
        // matrix cov_mat = B^T * X.
        {
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_("N",
                   "T",
                   &di,
                   &di,
                   &ni,
                   &one,
                   rotated_x.data(),
                   &di,
                   x.data(),
                   &di,
                   &zero,
                   cov_mat.data(),
                   &di);
        }
        print_if_verbose("cov_mat", cov_mat, d, d);

        // SVD: cov_mat = U * diag(singvals) * V^T. dgesvd_ stores U in u and
        // V^T (not V) in vt. cov_mat is overwritten by the call.
        {
            FINTEGER di = d;
            FINTEGER lwork = -1, info;
            double lwork1;

            // Workspace query (lwork == -1): the optimal size is returned
            // in lwork1.
            dgesvd_("A",
                    "A",
                    &di,
                    &di,
                    cov_mat.data(),
                    &di,
                    singvals.data(),
                    u.data(),
                    &di,
                    vt.data(),
                    &di,
                    &lwork1,
                    &lwork,
                    &info);
            FAISS_THROW_IF_NOT(info == 0);

            lwork = static_cast<FINTEGER>(lwork1);
            std::vector<double> work(lwork);
            dgesvd_("A",
                    "A",
                    &di,
                    &di,
                    cov_mat.data(),
                    &di,
                    singvals.data(),
                    u.data(),
                    &di,
                    vt.data(),
                    &di,
                    work.data(),
                    &lwork,
                    &info);
            FAISS_THROW_IF_NOT_FMT(
                    info == 0,
                    "dgesvd returned info=%d",
                    static_cast<int>(info));
        }
        print_if_verbose("u", u, d, d);
        print_if_verbose("vt", vt, d, d);

        // Rotation update: rotation = u * transpose(vt) (column-major).
        // NOTE(review): vt holds V^T, so this evaluates to U * V. The
        // textbook orthogonal-Procrustes update for this objective would be
        // U * V^T here (i.e. transb = "N"). Left unchanged because it alters
        // training results; confirm against the reference implementation
        // and tests before modifying.
        {
            FINTEGER di = d;
            double one = 1, zero = 0;
            dgemm_("N",
                   "T",
                   &di,
                   &di,
                   &di,
                   &one,
                   u.data(),
                   &di,
                   vt.data(),
                   &di,
                   &zero,
                   rotation.data(),
                   &di);
        }
        print_if_verbose("final rot", rotation, d, d);
    }

    // Transpose the flat buffer: A[i + d * j] = rotation[j + d * i]. Read
    // row-major, A is the same matrix that `rotation` holds column-major.
    A.resize(d * d);
    for (size_t i = 0; i < d; i++) {
        for (size_t j = 0; j < d; j++) {
            A[i + d * j] = rotation[j + d * i];
        }
    }

    is_trained = true;
}