in faiss/IndexHNSW.cpp [677:783]
/** Reconstruct vector i as a weighted sum of its own stored representation
 * and the stored representations of its level-0 HNSW neighbors.
 *
 * The weights come from `codebook`, organised in groups of (M + 1) floats:
 * slot 0 is the weight of vector i itself, and slot (j - begin + 1) is the
 * weight of the neighbor found at position j of i's level-0 neighbor range.
 * (Presumably end - begin == M; that is set up outside this function.)
 *
 * How the weight group is selected:
 *  - k == 1: the first group of the codebook is used for all d dimensions
 *    and no code is read for vector i;
 *  - nsq == 1: the group selected by codes[i] is used for all d dimensions;
 *  - otherwise the vector is split into nsq consecutive slices of dsub
 *    dimensions. Slice sq uses the group selected by the sq-th code byte of
 *    vector i, taken from the sq-th block of k groups in the codebook.
 *
 * Neighbor slots holding a negative id are replaced by vector i itself.
 *
 * @param i    id of the vector to reconstruct
 * @param x    output buffer; x[0..d) is overwritten
 * @param tmp  scratch buffer passed to index.storage->reconstruct
 *             (assumed to hold d floats); its contents are clobbered
 */
void ReconstructFromNeighbors::reconstruct(
        storage_idx_t i,
        float* x,
        float* tmp) const {
    const HNSW& hnsw = index.hnsw;
    size_t begin, end;
    hnsw.neighbor_range(i, 0, &begin, &end);

    if (k == 1 || nsq == 1) {
        // A single weight group applies to every dimension.
        const float* beta;
        if (k == 1) {
            beta = codebook.data();
        } else {
            int idx = codes[i];
            beta = codebook.data() + idx * (M + 1);
        }

        float w0 = beta[0]; // weight of image itself
        index.storage->reconstruct(i, tmp);
        for (size_t l = 0; l < d; l++)
            x[l] = w0 * tmp[l];

        for (size_t j = begin; j < end; j++) {
            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) // unused neighbor slot: fall back to vector i
                ji = i;
            float w = beta[j - begin + 1];
            index.storage->reconstruct(ji, tmp);
            for (size_t l = 0; l < d; l++)
                x[l] += w * tmp[l];
        }
    } else if (nsq == 2) {
        // Two slices: dims [0, dsub) use beta0, dims [dsub, d) use beta1.
        // beta1 lives in the second block of k weight groups, hence `+ k`.
        // The code offset is computed in size_t: evaluating `2 * i` in the
        // signed storage_idx_t type overflows (UB) once i >= 2^30.
        const size_t code_offset = static_cast<size_t>(i) * 2;
        int idx0 = codes[code_offset];
        int idx1 = codes[code_offset + 1];
        const float* beta0 = codebook.data() + idx0 * (M + 1);
        const float* beta1 = codebook.data() + (idx1 + k) * (M + 1);

        index.storage->reconstruct(i, tmp);
        float w0 = beta0[0];
        for (size_t l = 0; l < dsub; l++)
            x[l] = w0 * tmp[l];
        w0 = beta1[0];
        for (size_t l = dsub; l < d; l++)
            x[l] = w0 * tmp[l];

        for (size_t j = begin; j < end; j++) {
            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) // unused neighbor slot: fall back to vector i
                ji = i;
            index.storage->reconstruct(ji, tmp);
            float w = beta0[j - begin + 1];
            for (size_t l = 0; l < dsub; l++)
                x[l] += w * tmp[l];
            w = beta1[j - begin + 1];
            for (size_t l = dsub; l < d; l++)
                x[l] += w * tmp[l];
        }
    } else {
        // General case: nsq slices of dsub dims. Slice sq reads its weight
        // group from the sq-th block of k groups, selected by code byte sq.
        std::vector<const float*> betas(nsq);
        {
            const uint8_t* c = &codes[static_cast<size_t>(i) * code_size];
            const float* block = codebook.data();
            for (size_t sq = 0; sq < nsq; sq++) {
                betas[sq] = block + c[sq] * (M + 1);
                block += (M + 1) * k;
            }
        }

        // Weighted copy of vector i itself (slot 0 of each group).
        index.storage->reconstruct(i, tmp);
        for (size_t sq = 0; sq < nsq; sq++) {
            const float w = betas[sq][0];
            const size_t d0 = sq * dsub;
            const size_t d1 = d0 + dsub;
            for (size_t l = d0; l < d1; l++)
                x[l] = w * tmp[l];
        }

        // Accumulate weighted neighbors (slot j - begin + 1 of each group).
        for (size_t j = begin; j < end; j++) {
            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) // unused neighbor slot: fall back to vector i
                ji = i;
            index.storage->reconstruct(ji, tmp);
            const size_t slot = j - begin + 1;
            for (size_t sq = 0; sq < nsq; sq++) {
                const float w = betas[sq][slot];
                const size_t d0 = sq * dsub;
                const size_t d1 = d0 + dsub;
                for (size_t l = d0; l < d1; l++)
                    x[l] += w * tmp[l];
            }
        }
    }
}