in faiss/utils/quantize_lut.cpp [140:285]
/**
 * Quantize float look-up tables (and an optional per-probe bias) to integers
 * with one shared affine transform.
 *
 * A single scale `a` is chosen for all tables; each table is shifted by its
 * own minimum before scaling, and the accumulated shift is reported as `b`.
 * Four cases are handled:
 *   - bias == nullptr: 2D LUT only (lut_is_3d must be false).
 *   - bias != nullptr, !lut_is_3d: 2D LUT plus a separately quantized bias.
 *   - bias != nullptr, lut_is_3d, biasq != nullptr: one LUT per probe plus a
 *     quantized bias that also absorbs the per-probe sum of table minima.
 *   - bias != nullptr, lut_is_3d, biasq == nullptr: the bias is folded into
 *     the LUT entries themselves and no bias is output.
 *
 * @param nprobe   number of bias entries; also the number of LUT slabs when
 *                 lut_is_3d is true
 * @param M        number of tables per slab
 * @param ksub     number of entries per table
 * @param lut_is_3d  true if LUT holds nprobe * M tables, false if it holds M
 * @param LUT      input tables, read as consecutive blocks of ksub floats
 * @param bias     nprobe input biases, or nullptr
 * @param LUTq     output tables; each slab is written as M tables of ksub
 *                 bytes followed by (M2 - M) zero-filled tables
 * @param M2       number of tables per output slab, must be >= M
 * @param biasq    output for nprobe quantized biases; may be nullptr only in
 *                 the cases described above (in the 2D-with-bias case it is
 *                 passed to round_tab as the output buffer, so it presumably
 *                 must be non-null there -- confirm against round_tab)
 * @param a_out    if non-null, receives the scale a
 * @param b_out    if non-null, receives the offset b
 *
 * @throws via FAISS_THROW_IF_NOT when M2 < M, or when bias is nullptr while
 *         lut_is_3d is true.
 *
 * NOTE(review): when every table is constant all spans are 0 and the
 * divisions below yield an infinite scale; that case is not special-cased
 * here -- confirm how round_tab and the callers cope with it.
 */
void quantize_LUT_and_bias(
        size_t nprobe,
        size_t M,
        size_t ksub,
        bool lut_is_3d,
        const float* LUT,
        const float* bias,
        uint8_t* LUTq,
        size_t M2,
        uint16_t* biasq,
        float* a_out,
        float* b_out) {
    // The zero padding is sized with the unsigned difference M2 - M; without
    // this guard M2 < M would wrap around and memset far past the buffer.
    FAISS_THROW_IF_NOT(M2 >= M);
    const size_t n_pad_tables = M2 - M;

    float a, b;
    if (!bias) {
        // No bias: only a 2D LUT makes sense here.
        FAISS_THROW_IF_NOT(!lut_is_3d);
        std::vector<float> mins(M);
        float max_span_LUT = -HUGE_VAL, max_span_dis = 0;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            mins[i] = tab_min(LUT + i * ksub, ksub);
            float span = tab_max(LUT + i * ksub, ksub) - mins[i];
            max_span_LUT = std::max(max_span_LUT, span);
            max_span_dis += span; // span of a sum of one entry per table
            b += mins[i];
        }
        // Largest single-table span is scaled to at most 255 and the summed
        // span to at most 65535.
        a = std::min(255 / max_span_LUT, 65535 / max_span_dis);

        for (size_t i = 0; i < M; i++) {
            round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub);
        }
        // Zero the padding tables M..M2-1.
        memset(LUTq + M * ksub, 0, ksub * n_pad_tables);
    } else if (!lut_is_3d) {
        // 2D LUT shared by all probes, bias quantized separately.
        std::vector<float> mins(M);
        float max_span_LUT = -HUGE_VAL, max_span_dis;
        float bias_min = tab_min(bias, nprobe);
        float bias_max = tab_max(bias, nprobe);
        // The bias range contributes to the span of the final sum.
        max_span_dis = bias_max - bias_min;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            mins[i] = tab_min(LUT + i * ksub, ksub);
            float span = tab_max(LUT + i * ksub, ksub) - mins[i];
            max_span_LUT = std::max(max_span_LUT, span);
            max_span_dis += span;
            b += mins[i];
        }
        a = std::min(255 / max_span_LUT, 65535 / max_span_dis);
        b += bias_min;

        for (size_t i = 0; i < M; i++) {
            round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub);
        }
        memset(LUTq + M * ksub, 0, ksub * n_pad_tables);
        round_tab(bias, nprobe, a, bias_min, biasq);
    } else if (biasq) {
        // LUT is 3D: one slab of M tables per probe. Every table is shifted
        // by its own minimum; the per-probe sum of those minima is moved
        // into the bias (bias2) so that it can be quantized along with it.
        std::vector<float> mins(nprobe * M);
        std::vector<float> bias2(nprobe);
        float bias_min = tab_min(bias, nprobe);
        float max_span_LUT = -HUGE_VAL, max_span_dis = -HUGE_VAL;
        b = HUGE_VAL;
        size_t ij = 0; // flat index of table (j, i) in the input
        for (size_t j = 0; j < nprobe; j++) {
            float max_span_dis_j = bias[j] - bias_min;
            float b2j = bias[j];
            for (size_t i = 0; i < M; i++) {
                mins[ij] = tab_min(LUT + ij * ksub, ksub);
                float span = tab_max(LUT + ij * ksub, ksub) - mins[ij];
                max_span_LUT = std::max(max_span_LUT, span);
                max_span_dis_j += span;
                b2j += mins[ij];
                ij++;
            }
            max_span_dis = std::max(max_span_dis, max_span_dis_j);
            bias2[j] = b2j;
            b = std::min(b, b2j); // global offset = smallest adjusted bias
        }

        a = std::min(255 / max_span_LUT, 65535 / max_span_dis);

        ij = 0;
        size_t ij_2 = 0; // flat index of the table in the padded output
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                round_tab(
                        LUT + ij * ksub, ksub, a, mins[ij], LUTq + ij_2 * ksub);
                ij++;
                ij_2++;
            }
            // Zero the padding tables of this slab, then skip over them.
            memset(LUTq + ij_2 * ksub, 0, ksub * n_pad_tables);
            ij_2 += n_pad_tables;
        }

        round_tab(bias2.data(), nprobe, a, b, biasq);
    } else { // !biasq
        // then we integrate the bias into the LUTs:
        // bias[j] / M is added to every entry of slab j, so that summing
        // one entry from each of the M tables adds bias[j] in total.
        std::vector<float> LUT2_storage(nprobe * M * ksub);
        float* LUT2 = LUT2_storage.data();
        size_t ijc = 0;
        for (size_t j = 0; j < nprobe; j++) {
            float bias_j = bias[j] / M;
            for (size_t i = 0; i < M; i++) {
                for (size_t c = 0; c < ksub; c++) {
                    LUT2[ijc] = LUT[ijc] + bias_j;
                    ijc++;
                }
            }
        }

        // Min / max of each table position i, taken across all probes, so
        // that a given position uses the same shift in every slab.
        std::vector<float> mins(M, HUGE_VAL), maxs(M, -HUGE_VAL);
        size_t ij = 0;
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                mins[i] = std::min(mins[i], tab_min(LUT2 + ij * ksub, ksub));
                maxs[i] = std::max(maxs[i], tab_max(LUT2 + ij * ksub, ksub));
                ij++;
            }
        }

        float max_span = -HUGE_VAL;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            float span = maxs[i] - mins[i];
            max_span = std::max(max_span, span);
            b += mins[i];
        }
        // Only the 255 bound on a single table is applied in this case.
        a = 255 / max_span;

        ij = 0;
        size_t ij_2 = 0; // flat index of the table in the padded output
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                round_tab(
                        LUT2 + ij * ksub, ksub, a, mins[i], LUTq + ij_2 * ksub);
                ij++;
                ij_2++;
            }
            memset(LUTq + ij_2 * ksub, 0, ksub * n_pad_tables);
            ij_2 += n_pad_tables;
        }
    }

    // Both normalizer outputs are optional.
    if (a_out)
        *a_out = a;
    if (b_out)
        *b_out = b;
}