in faiss/impl/ResidualQuantizer.cpp [237:415]
/// Train the residual quantizer: learn one codebook per step (M steps) on
/// the residuals left by the previous steps, re-encoding the training set
/// with a beam search between steps. Also estimates the norm statistics
/// (norm_min / norm_max, and the norm codebook for cqint search types).
/// @param n  number of training vectors
/// @param x  training vectors, size n * d
void ResidualQuantizer::train(size_t n, const float* x) {
    codebooks.resize(d * codebook_offsets.back());

    if (verbose) {
        printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n",
               M,
               n,
               size_t(d));
    }

    // number of beam entries currently kept per training vector
    int cur_beam_size = 1;
    // residuals of the current step, size n * cur_beam_size * d;
    // initially the raw training vectors (residual w.r.t. an empty code)
    std::vector<float> residuals(x, x + n * d);
    // codes of the beam entries, size n * cur_beam_size * m
    std::vector<int32_t> codes;
    // quantization distances of the beam entries, size n * cur_beam_size
    std::vector<float> distances;
    double t0 = getmillisecs();
    double clustering_time = 0;

    for (int m = 0; m < M; m++) {
        int K = 1 << nbits[m];

        // Select which residuals to train on: all beam entries by default,
        // or only the first entry of each point's beam (Train_top_beam).
        // NOTE(fix): this must be a pointer, not a reference. The previous
        // code declared `std::vector<float>& train_residuals = residuals;`
        // and later did `train_residuals = residuals1;` — references cannot
        // be reseated, so that copy-assigned residuals1 *into* residuals,
        // shrinking it to n * d and clobbering the full beam residuals that
        // beam_search_encode_step still reads below (out-of-bounds when
        // cur_beam_size > 1).
        const std::vector<float>* train_residuals = &residuals;
        std::vector<float> residuals1;
        if (train_type & Train_top_beam) {
            residuals1.resize(n * d);
            for (size_t j = 0; j < n; j++) {
                // entry 0 of each beam is presumably the best one — the
                // ordering is produced by beam_search_encode_step
                memcpy(residuals1.data() + j * d,
                       residuals.data() + j * d * cur_beam_size,
                       sizeof(residuals[0]) * d);
            }
            train_residuals = &residuals1;
        }

        // codebook learned at this step (renamed from `codebooks` which
        // shadowed the member of the same name)
        std::vector<float> step_codebooks;
        float obj = 0;

        std::unique_ptr<Index> assign_index;
        if (assign_index_factory) {
            assign_index.reset((*assign_index_factory)(d));
        } else {
            assign_index.reset(new IndexFlatL2(d));
        }

        double t1 = getmillisecs();

        if (!(train_type & Train_progressive_dim)) { // regular kmeans
            Clustering clus(d, K, cp);
            clus.train(
                    train_residuals->size() / d,
                    train_residuals->data(),
                    *assign_index.get());
            step_codebooks.swap(clus.centroids);
            assign_index->reset();
            obj = clus.iteration_stats.back().obj;
        } else { // progressive dim clustering
            ProgressiveDimClustering clus(d, K, cp);
            ProgressiveDimIndexFactory default_fac;
            clus.train(
                    train_residuals->size() / d,
                    train_residuals->data(),
                    assign_index_factory ? *assign_index_factory : default_fac);
            step_codebooks.swap(clus.centroids);
            obj = clus.iteration_stats.back().obj;
        }
        clustering_time += (getmillisecs() - t1) / 1000;

        // copy this step's codebook into the flat member array
        memcpy(this->codebooks.data() + codebook_offsets[m] * d,
               step_codebooks.data(),
               step_codebooks.size() * sizeof(step_codebooks[0]));

        // quantize using the new codebooks

        int new_beam_size = std::min(cur_beam_size * K, max_beam_size);
        std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
        std::vector<float> new_residuals(n * new_beam_size * d);
        std::vector<float> new_distances(n * new_beam_size);

        size_t bs;
        { // determine batch size so temp memory stays under max_mem_distances
            size_t mem = memory_per_point();
            if (n > 1 && mem * n > max_mem_distances) {
                // then split queries to reduce temp memory
                bs = std::max(max_mem_distances / mem, size_t(1));
            } else {
                bs = n;
            }
        }

        for (size_t i0 = 0; i0 < n; i0 += bs) {
            size_t i1 = std::min(i0 + bs, n);

            /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
                i0, i1, K, assign_index->ntotal); */

            beam_search_encode_step(
                    d,
                    K,
                    step_codebooks.data(),
                    i1 - i0,
                    cur_beam_size,
                    residuals.data() + i0 * cur_beam_size * d,
                    m,
                    codes.data() + i0 * cur_beam_size * m,
                    new_beam_size,
                    new_codes.data() + i0 * new_beam_size * (m + 1),
                    new_residuals.data() + i0 * new_beam_size * d,
                    new_distances.data() + i0 * new_beam_size,
                    assign_index.get());
        }

        codes.swap(new_codes);
        residuals.swap(new_residuals);
        distances.swap(new_distances);

        float sum_distances = 0;
        // size_t index avoids the signed/unsigned comparison of the
        // previous `int j` loop
        for (size_t j = 0; j < distances.size(); j++) {
            sum_distances += distances[j];
        }

        if (verbose) {
            printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
                   "total distance %g, beam_size %d->%d (batch size %zd)\n",
                   (getmillisecs() - t0) / 1000,
                   clustering_time,
                   m,
                   int(nbits[m]),
                   obj,
                   sum_distances,
                   cur_beam_size,
                   new_beam_size,
                   bs);
        }
        cur_beam_size = new_beam_size;
    }

    is_trained = true;

    if (train_type & Train_refine_codebook) {
        for (int iter = 0; iter < niter_codebook_refine; iter++) {
            if (verbose) {
                printf("re-estimating the codebooks to minimize "
                       "quantization errors (iter %d).\n",
                       iter);
            }
            retrain_AQ_codebook(n, x);
        }
    }

    // find min and max norms
    std::vector<float> norms(n);

    for (size_t i = 0; i < n; i++) {
        // ||x_i - residual_i||^2, i.e. the squared norm of the
        // reconstruction from the best beam entry of point i
        norms[i] = fvec_L2sqr(
                x + i * d, residuals.data() + i * cur_beam_size * d, d);
    }

    norm_min = HUGE_VALF;
    norm_max = -HUGE_VALF;
    for (size_t i = 0; i < n; i++) {
        if (norms[i] < norm_min) {
            norm_min = norms[i];
        }
        if (norms[i] > norm_max) {
            norm_max = norms[i];
        }
    }

    if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
        // train a 1D quantizer on the norms (256 or 16 centroids)
        size_t k = (1 << 8);
        if (search_type == ST_norm_cqint4) {
            k = (1 << 4);
        }
        Clustering1D clus(k);
        clus.train_exact(n, norms.data());
        qnorm.add(clus.k, clus.centroids.data());
    }

    if (!(train_type & Skip_codebook_tables)) {
        compute_codebook_tables();
    }
}