in faiss/Clustering.cpp [271:555]
void Clustering::train_encoded(
idx_t nx,
const uint8_t* x_in,
const Index* codec,
Index& index,
const float* weights) {
FAISS_THROW_IF_NOT_FMT(
nx >= k,
"Number of training points (%" PRId64
") should be at least "
"as large as number of clusters (%zd)",
nx,
k);
FAISS_THROW_IF_NOT_FMT(
(!codec || codec->d == d),
"Codec dimension %d not the same as data dimension %d",
int(codec->d),
int(d));
FAISS_THROW_IF_NOT_FMT(
index.d == d,
"Index dimension %d not the same as data dimension %d",
int(index.d),
int(d));
double t0 = getmillisecs();
if (!codec) {
// Check for NaNs in input data. Normally it is the user's
// responsibility, but it may spare us some hard-to-debug
// reports.
const float* x = reinterpret_cast<const float*>(x_in);
for (size_t i = 0; i < nx * d; i++) {
FAISS_THROW_IF_NOT_MSG(
std::isfinite(x[i]), "input contains NaN's or Inf's");
}
}
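    // from here on, x points to the training data: raw floats if codec ==
    // nullptr, sa-encoded vectors otherwise; del1/del3 own subsampled copies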
const uint8_t* x = x_in;
std::unique_ptr<uint8_t[]> del1;
std::unique_ptr<float[]> del3;
size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
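    // subsample if the training set is much larger than needed; warn (but
    // proceed) if it is too small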
if (nx > k * max_points_per_centroid) {
uint8_t* x_new;
float* weights_new;
nx = subsample_training_set(
*this, nx, x, line_size, weights, &x_new, &weights_new);
del1.reset(x_new);
x = x_new;
del3.reset(weights_new);
weights = weights_new;
} else if (nx < k * min_points_per_centroid) {
fprintf(stderr,
"WARNING clustering %" PRId64
" points to %zd centroids: "
"please provide at least %" PRId64 " training points\n",
nx,
k,
idx_t(k) * min_points_per_centroid);
}
if (nx == k) {
// this is a corner case, just copy training set to clusters
if (verbose) {
printf("Number of training points (%" PRId64
") same as number of "
"clusters, just copying\n",
nx);
}
centroids.resize(d * k);
if (!codec) {
memcpy(centroids.data(), x_in, sizeof(float) * d * k);
} else {
codec->sa_decode(nx, x_in, centroids.data());
}
// one fake iteration...
ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
iteration_stats.push_back(stats);
index.reset();
index.add(k, centroids.data());
return;
}
if (verbose) {
printf("Clustering %" PRId64
" points in %zdD to %zd clusters, "
"redo %d times, %d iterations\n",
nx,
d,
k,
nredo,
niter);
if (codec) {
printf("Input data encoded in %zd bytes per vector\n",
codec->sa_code_size());
}
}
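    // per-point buffers for the nearest centroid id and the distance to it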
std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
std::unique_ptr<float[]> dis(new float[nx]);
// remember best iteration for redo
bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
float best_obj = lower_is_better ? HUGE_VALF : -HUGE_VALF;
std::vector<ClusteringIterationStats> best_iteration_stats;
std::vector<float> best_centroids;
// support input centroids
FAISS_THROW_IF_NOT_MSG(
centroids.size() % d == 0,
"size of provided input centroids not a multiple of dimension");
size_t n_input_centroids = centroids.size() / d;
if (verbose && n_input_centroids > 0) {
printf(" Using %zd centroids provided as input (%sfrozen)\n",
n_input_centroids,
frozen_centroids ? "" : "not ");
}
double t_search_tot = 0;
if (verbose) {
printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
}
t0 = getmillisecs();
// temporary buffer to decode vectors during the optimization
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
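    // each redo is an independent k-means run; the best run by objective is
    // kept (see end of function)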
for (int redo = 0; redo < nredo; redo++) {
if (verbose && nredo > 1) {
printf("Outer iteration %d / %d\n", redo, nredo);
}
// initialize (remaining) centroids with random points from the dataset
centroids.resize(d * k);
std::vector<int> perm(nx);
rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
if (!codec) {
for (int i = n_input_centroids; i < k; i++) {
                memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
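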
}
} else {
for (int i = n_input_centroids; i < k; i++) {
                codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
}
}
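        // apply spherical normalization and/or integer rounding if the
        // corresponding flags are set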
post_process_centroids();
// prepare the index
if (index.ntotal != 0) {
index.reset();
}
if (!index.is_trained) {
index.train(k, centroids.data());
}
index.add(k, centroids.data());
// k-means iterations
float obj = 0;
for (int i = 0; i < niter; i++) {
double t0s = getmillisecs();
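            // E-step: find the nearest centroid for every training point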
if (!codec) {
index.search(
nx,
reinterpret_cast<const float*>(x),
1,
dis.get(),
assign.get());
} else {
// search by blocks of decode_block_size vectors
size_t code_size = codec->sa_code_size();
for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
size_t i1 = i0 + decode_block_size;
if (i1 > nx) {
i1 = nx;
}
codec->sa_decode(
i1 - i0, x + code_size * i0, decode_buffer.data());
index.search(
i1 - i0,
decode_buffer.data(),
1,
dis.get() + i0,
assign.get() + i0);
}
}
InterruptCallback::check();
t_search_tot += getmillisecs() - t0s;
// accumulate objective
obj = 0;
            for (idx_t j = 0; j < nx; j++) {
                obj += dis[j];
            }
// update the centroids
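            // M-step: hassign[j] accumulates the (weighted) count of points
            // assigned to centroid j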
std::vector<float> hassign(k);
size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
compute_centroids(
d,
k,
nx,
k_frozen,
x,
codec,
assign.get(),
weights,
hassign.data(),
centroids.data());
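            // handle empty clusters by splitting bigger ones; nsplit is the
            // number of clusters that had to be split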
int nsplit = split_clusters(
d, k, nx, k_frozen, hassign.data(), centroids.data());
// collect statistics
ClusteringIterationStats stats = {
obj,
(getmillisecs() - t0) / 1000.0,
t_search_tot / 1000,
imbalance_factor(nx, k, assign.get()),
nsplit};
iteration_stats.push_back(stats);
if (verbose) {
printf(" Iteration %d (%.2f s, search %.2f s): "
"objective=%g imbalance=%.3f nsplit=%d \r",
i,
stats.time,
stats.time_search,
stats.obj,
stats.imbalance_factor,
nsplit);
fflush(stdout);
}
post_process_centroids();
// add centroids to index for the next iteration (or for output)
index.reset();
if (update_index) {
index.train(k, centroids.data());
}
index.add(k, centroids.data());
InterruptCallback::check();
}
if (verbose)
printf("\n");
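        // keep the best centroids across redos, as measured by the objective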
if (nredo > 1) {
if ((lower_is_better && obj < best_obj) ||
(!lower_is_better && obj > best_obj)) {
if (verbose) {
printf("Objective improved: keep new clusters\n");
}
best_centroids = centroids;
best_iteration_stats = iteration_stats;
best_obj = obj;
}
index.reset();
}
}
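    // restore the best run's centroids and stats, and rebuild the index on
    // those centroids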
if (nredo > 1) {
centroids = best_centroids;
iteration_stats = best_iteration_stats;
index.reset();
index.add(k, best_centroids.data());
}
}
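
// A minimal usage sketch (illustrative, not part of the original file; names
// below are assumptions). train_encoded() is normally reached through
// Clustering::train(), which passes codec == nullptr for raw float input:
//
//     faiss::Clustering clus(d, k);          // d-dim vectors, k clusters
//     faiss::IndexFlatL2 assign_index(d);    // index used for assignment
//     clus.train(n, data, assign_index);     // data: n * d floats
//     // clus.centroids now holds k * d floats; assign_index contains the
//     // final centroids and can be searched directly
//
// With a codec, the same loop runs on compressed input: pass the sa-encoded
// codes and the codec, e.g.
//
//     clus.train_encoded(n, codes, &codec, assign_index);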