in faiss/impl/ResidualQuantizer.cpp [105:235]
/** One step of beam search for residual quantization.
 *
 * For each of the n queries, every one of the beam_size current hypotheses
 * is extended with candidate centroids; the new_beam_size best
 * (hypothesis, centroid) pairs by squared L2 distance are kept, and their
 * codes and residuals are written out.
 *
 * @param d              vector dimension
 * @param K              number of centroids in this codebook
 * @param cent           centroid table, size (K, d)
 * @param n              number of queries
 * @param beam_size      current number of hypotheses per query
 * @param residuals      current residuals, size (n, beam_size, d)
 * @param m              number of codes already assigned per hypothesis
 * @param codes          current codes, size (n, beam_size, m)
 * @param new_beam_size  number of hypotheses to keep (<= beam_size * K)
 * @param new_codes      output codes, size (n, new_beam_size, m + 1)
 * @param new_residuals  output residuals, size (n, new_beam_size, d)
 * @param new_distances  output distances, size (n, new_beam_size)
 * @param assign_index   optional index used to assign centroids
 *                       approximately; if null, distances are computed
 *                       exhaustively
 */
void beam_search_encode_step(
        size_t d,
        size_t K,
        const float* cent, /// size (K, d)
        size_t n,
        size_t beam_size,
        const float* residuals, /// size (n, beam_size, d)
        size_t m,
        const int32_t* codes, /// size (n, beam_size, m)
        size_t new_beam_size,
        int32_t* new_codes, /// size (n, new_beam_size, m + 1)
        float* new_residuals, /// size (n, new_beam_size, d)
        float* new_distances, /// size (n, new_beam_size)
        Index* assign_index) {
    // we have to fill in the whole output matrix
    FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);

    using idx_t = Index::idx_t;

    std::vector<float> cent_distances;
    std::vector<idx_t> cent_ids;

    if (assign_index) {
        // Approximate path: for each (query, beam) residual, retrieve only
        // its new_beam_size nearest centroids through the index.
        FAISS_THROW_IF_NOT(assign_index->d == d);
        cent_distances.resize(n * beam_size * new_beam_size);
        cent_ids.resize(n * beam_size * new_beam_size);
        if (assign_index->ntotal != 0) {
            // then we assume the codebooks are already added to the index
            FAISS_THROW_IF_NOT(assign_index->ntotal == K);
        } else {
            assign_index->add(K, cent);
        }

        // printf("beam_search_encode_step -- mem usage %zd\n",
        // get_mem_usage_kb());
        assign_index->search(
                n * beam_size,
                residuals,
                new_beam_size,
                cent_distances.data(),
                cent_ids.data());
    } else {
        // Exhaustive path: one big (n * beam_size) x K distance matrix.
        cent_distances.resize(n * beam_size * K);
        pairwise_L2sqr(
                d, n * beam_size, residuals, K, cent, cent_distances.data());
    }

    InterruptCallback::check();

#pragma omp parallel for if (n > 100)
    for (int64_t i = 0; i < int64_t(n); i++) {
        const int32_t* codes_i = codes + i * m * beam_size;
        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
        const float* residuals_i = residuals + i * d * beam_size;
        float* new_residuals_i = new_residuals + i * d * new_beam_size;

        float* new_distances_i = new_distances + i * new_beam_size;
        using C = CMax<float, int>;

        // Append centroid ls to the codes of source beam js, compute the
        // updated residual, and advance the output cursors.
        auto emit_result = [&](int js, int ls) {
            if (m > 0) {
                memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
            }
            new_codes_i[m] = ls;
            new_codes_i += m + 1;
            fvec_sub(d, residuals_i + js * d, cent + ls * d, new_residuals_i);
            new_residuals_i += d;
        };

        if (assign_index) {
            const float* cent_distances_i =
                    cent_distances.data() + i * beam_size * new_beam_size;
            const idx_t* cent_ids_i =
                    cent_ids.data() + i * beam_size * new_beam_size;
            // here we could be a tad more efficient by merging sorted arrays
            // NOTE: loop variable renamed from `i` to avoid shadowing the
            // outer query index.
            for (int l = 0; l < new_beam_size; l++) {
                new_distances_i[l] = C::neutral();
            }
            std::vector<int> perm(new_beam_size, -1);
            // Select the new_beam_size smallest candidate distances;
            // perm[j] is the flat index (js * new_beam_size + rank) of the
            // j-th best candidate.
            heap_addn<C>(
                    new_beam_size,
                    new_distances_i,
                    perm.data(),
                    cent_distances_i,
                    nullptr,
                    beam_size * new_beam_size);
            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());

            for (int j = 0; j < new_beam_size; j++) {
                int js = perm[j] / new_beam_size; // source beam
                int ls = cent_ids_i[perm[j]];     // selected centroid id
                emit_result(js, ls);
            }
        } else {
            const float* cent_distances_i =
                    cent_distances.data() + i * beam_size * K;
            // then we have to select the best results
            for (int l = 0; l < new_beam_size; l++) {
                new_distances_i[l] = C::neutral();
            }
            std::vector<int> perm(new_beam_size, -1);
            // perm[j] encodes (source beam js, centroid ls) as js * K + ls.
            heap_addn<C>(
                    new_beam_size,
                    new_distances_i,
                    perm.data(),
                    cent_distances_i,
                    nullptr,
                    beam_size * K);
            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());

            for (int j = 0; j < new_beam_size; j++) {
                int js = perm[j] / K; // source beam
                int ls = perm[j] % K; // selected centroid id
                emit_result(js, ls);
            }
        }
    }
}