void quantize_LUT_and_bias()

in faiss/utils/quantize_lut.cpp [140:285]


/** Quantize float look-up tables (and an optional per-probe bias) to 8-bit
 * tables and 16-bit biases that share one affine scaling (a, b).
 *
 * Every row of ksub floats is encoded relative to an offset ("min") as
 *     LUTq[row][c] ~= (LUT[row][c] - min) * a
 * via round_tab (assumed to round (x - offset) * scale; defined elsewhere).
 * The offsets are chosen so that a float distance obtained by summing one
 * entry per sub-table (plus the bias, when present) is recovered as
 * approximately  b + total_q / a,  where total_q is the sum of the selected
 * quantized entries (plus biasq[j] when biasq is produced).
 * The scale a makes every quantized row fit in [0, 255] and the largest
 * possible accumulated total fit in [0, 65535].
 *
 * @param nprobe    number of entries in bias / number of LUTs when 3D
 * @param M         number of sub-tables per LUT
 * @param ksub      number of entries per sub-table
 * @param lut_is_3d LUT holds nprobe consecutive M x ksub tables instead of one
 * @param LUT       input float tables
 * @param bias      optional per-probe float bias (nprobe values), may be null
 * @param LUTq      output 8-bit tables with M2 rows per table; rows M..M2-1
 *                  are zeroed
 * @param M2        padded number of rows per output table, must be >= M
 * @param biasq     output 16-bit bias (nprobe values). Must be non-null for a
 *                  2D LUT with bias; for a 3D LUT it may be null, in which
 *                  case the bias is folded into LUTq instead.
 * @param a_out     optional output: scale a
 * @param b_out     optional output: offset b
 */
void quantize_LUT_and_bias(
        size_t nprobe,
        size_t M,
        size_t ksub,
        bool lut_is_3d,
        const float* LUT,
        const float* bias,
        uint8_t* LUTq,
        size_t M2,
        uint16_t* biasq,
        float* a_out,
        float* b_out) {
    // Padding rows M..M2-1 are zeroed below; with M2 < M the memset size
    // ksub * (M2 - M) would wrap around to a huge value.
    FAISS_THROW_IF_NOT(M2 >= M);

    // Scale mapping the widest single row onto [0, 255] and the widest
    // accumulated distance onto [0, 65535].
    auto compute_scale = [](float max_span_LUT, float max_span_dis) {
        float scale = std::min(255 / max_span_LUT, 65535 / max_span_dis);
        // When every span is zero the divisions yield +inf and round_tab
        // would evaluate 0 * inf = NaN and convert it to an integer
        // (undefined behavior). Any finite scale encodes constant tables
        // as 0, so fall back to 1.
        if (!(scale < HUGE_VAL)) {
            scale = 1;
        }
        return scale;
    };

    float a, b;
    if (!bias) {
        // 2D LUT without bias: b is the sum of the per-row minima.
        FAISS_THROW_IF_NOT(!lut_is_3d);
        std::vector<float> mins(M);
        float max_span_LUT = -HUGE_VAL, max_span_dis = 0;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            mins[i] = tab_min(LUT + i * ksub, ksub);
            float span = tab_max(LUT + i * ksub, ksub) - mins[i];
            max_span_LUT = std::max(max_span_LUT, span);
            max_span_dis += span;
            b += mins[i];
        }
        a = compute_scale(max_span_LUT, max_span_dis);

        for (size_t i = 0; i < M; i++) {
            round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub);
        }
        memset(LUTq + M * ksub, 0, ksub * (M2 - M));
    } else if (!lut_is_3d) {
        // 2D LUT shared by all probes; the bias is quantized separately
        // into biasq relative to its own minimum.
        std::vector<float> mins(M);
        float bias_min = tab_min(bias, nprobe);
        float bias_max = tab_max(bias, nprobe);
        float max_span_LUT = -HUGE_VAL;
        float max_span_dis = bias_max - bias_min;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            mins[i] = tab_min(LUT + i * ksub, ksub);
            float span = tab_max(LUT + i * ksub, ksub) - mins[i];
            max_span_LUT = std::max(max_span_LUT, span);
            max_span_dis += span;
            b += mins[i];
        }
        a = compute_scale(max_span_LUT, max_span_dis);
        b += bias_min;

        for (size_t i = 0; i < M; i++) {
            round_tab(LUT + i * ksub, ksub, a, mins[i], LUTq + i * ksub);
        }
        memset(LUTq + M * ksub, 0, ksub * (M2 - M));
        round_tab(bias, nprobe, a, bias_min, biasq);

    } else if (biasq) {
        // 3D LUT: one M x ksub table per probe. Each row keeps its own
        // minimum; the per-probe sum of minima is moved into the bias.
        std::vector<float> mins(nprobe * M);
        std::vector<float> bias2(nprobe);    // bias[j] + sum_i mins[j, i]
        std::vector<float> span_sum(nprobe); // sum_i span[j, i]
        float max_span_LUT = -HUGE_VAL;

        b = HUGE_VAL;
        size_t ij = 0;
        for (size_t j = 0; j < nprobe; j++) {
            float b2j = bias[j];
            float span_sum_j = 0;
            for (size_t i = 0; i < M; i++) {
                mins[ij] = tab_min(LUT + ij * ksub, ksub);
                float span = tab_max(LUT + ij * ksub, ksub) - mins[ij];
                max_span_LUT = std::max(max_span_LUT, span);
                span_sum_j += span;
                b2j += mins[ij];
                ij++;
            }
            bias2[j] = b2j;
            span_sum[j] = span_sum_j;
            b = std::min(b, b2j);
        }

        // biasq[j] encodes bias2[j] - b (not bias[j] - min(bias)), so the
        // largest accumulated value for probe j is
        // (bias2[j] - b) + span_sum[j]; the 16-bit bound must cover that.
        float max_span_dis = -HUGE_VAL;
        for (size_t j = 0; j < nprobe; j++) {
            max_span_dis = std::max(max_span_dis, bias2[j] - b + span_sum[j]);
        }
        a = compute_scale(max_span_LUT, max_span_dis);

        ij = 0;
        size_t ij_2 = 0;
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                round_tab(
                        LUT + ij * ksub, ksub, a, mins[ij], LUTq + ij_2 * ksub);
                ij++;
                ij_2++;
            }
            // zero the M2 - M padding rows of this probe's output table
            memset(LUTq + ij_2 * ksub, 0, ksub * (M2 - M));
            ij_2 += M2 - M;
        }

        round_tab(bias2.data(), nprobe, a, b, biasq);

    } else { // !biasq
        // 3D LUT without a biasq output: integrate the bias into the LUTs
        // by adding bias[j] / M to every row of probe j.
        std::vector<float> LUT2_storage(nprobe * M * ksub);
        float* LUT2 = LUT2_storage.data();
        size_t ijc = 0;
        for (size_t j = 0; j < nprobe; j++) {
            float bias_j = bias[j] / M;
            for (size_t i = 0; i < M; i++) {
                for (size_t c = 0; c < ksub; c++) {
                    LUT2[ijc] = LUT[ijc] + bias_j;
                    ijc++;
                }
            }
        }

        // With no per-probe bias to absorb offsets, b is a single scalar, so
        // sub-table i must use one offset across all probes: take the
        // min/max of row i over every probe.
        std::vector<float> mins(M, HUGE_VAL), maxs(M, -HUGE_VAL);
        size_t ij = 0;
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                mins[i] = std::min(mins[i], tab_min(LUT2 + ij * ksub, ksub));
                maxs[i] = std::max(maxs[i], tab_max(LUT2 + ij * ksub, ksub));
                ij++;
            }
        }

        float max_span = -HUGE_VAL, span_total = 0;
        b = 0;
        for (size_t i = 0; i < M; i++) {
            float span = maxs[i] - mins[i];
            max_span = std::max(max_span, span);
            span_total += span;
            b += mins[i];
        }
        // The 16-bit bound matches the other branches; since
        // span_total <= M * max_span it is never tighter than the 8-bit
        // bound for M <= 256.
        a = compute_scale(max_span, span_total);

        ij = 0;
        size_t ij_2 = 0;
        for (size_t j = 0; j < nprobe; j++) {
            for (size_t i = 0; i < M; i++) {
                round_tab(
                        LUT2 + ij * ksub, ksub, a, mins[i], LUTq + ij_2 * ksub);
                ij++;
                ij_2++;
            }
            // zero the M2 - M padding rows of this probe's output table
            memset(LUTq + ij_2 * ksub, 0, ksub * (M2 - M));
            ij_2 += M2 - M;
        }
    }
    if (a_out)
        *a_out = a;
    if (b_out)
        *b_out = b;
}