uint16_t simd_partition_fuzzy_with_bounds_histogram()

in faiss/utils/partitioning.cpp [515:685]


/**
 * Histogram-driven bisection over the uint16 values in vals: narrows the
 * value range [s0i, s1i] until a threshold is found whose quantile is
 * acceptable, then hands that threshold to simd_compress_array<C>.
 *
 * @param vals   n uint16 values; read by the histogram helpers and passed on
 *               to simd_compress_array
 * @param ids    forwarded as-is to simd_compress_array
 * @param n      number of entries in vals
 * @param q_min  smallest acceptable quantile
 * @param q_max  largest acceptable quantile
 * @param q_out  if non-null, receives the quantile that was actually used
 * @param s0i    inclusive lower bound of the value range to search
 * @param s1i    inclusive upper bound of the value range to search
 * @return the selected threshold. Shortcuts that return before vals/ids are
 *         looked at: 0 when q_min == 0, 0xffff when q_max >= n, and s0i when
 *         s0i == s1i.
 */
uint16_t simd_partition_fuzzy_with_bounds_histogram(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out,
        uint16_t s0i,
        uint16_t s1i) {
    // Degenerate requests are answered without scanning the data.
    if (q_min == 0) {
        if (q_out != nullptr) {
            *q_out = 0;
        }
        return 0;
    }
    if (q_max >= n) {
        if (q_out != nullptr) {
            *q_out = q_max;
        }
        return 0xffff;
    }
    if (s0i == s1i) {
        if (q_out != nullptr) {
            *q_out = q_min;
        }
        return s0i;
    }

    IFV printf(
            "partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
            q_min,
            q_max,
            n,
            s0i,
            s1i);

    // The search below counts elements from the low end of the value range.
    // When C::is_max is false both quantile bounds are mirrored to n - q
    // here, and the result is mirrored back after the loop.
    // NOTE(review): if the caller passes q_min <= q_max, the mirrored values
    // satisfy q_min >= q_max, so the FOUND1/FOUND2 range tests below can only
    // succeed when the two are equal -- confirm whether a swap was intended.
    if (!C::is_max) {
        IFV printf(
                "revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
        q_min = n - q_min;
        q_max = n - q_max;
    }

    // current search range, both ends inclusive
    int lo = s0i;
    int hi = s1i;
    // how many values are known to lie strictly below lo / strictly above hi
    size_t below = 0;
    size_t above = 0;

    // Results of the search, assigned on the three break paths of the loop.
    // NOTE(review): threshold and quantile stay uninitialized if all 20
    // rounds complete without a break.
    int threshold;           // chosen threshold
    uint64_t ties_total = 0; // values equal to the threshold
    uint64_t ties_kept = 0;  // how many of those ties count towards quantile
    size_t quantile;         // quantile reached with that threshold

    // bin counts filled in by simd_histogram_8 / simd_histogram_16
    int bins[16];

    for (int round = 0; round < 20; round++) {
        IFV printf(
                "  it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
                round,
                lo,
                hi,
                below,
                above);

        // Halve the span until it fits in bin indices 0..15; log2_width is
        // the number of halvings, top_bin the index of the last bin.
        int top_bin = hi - lo;
        int log2_width = 0;
        for (; top_bin > 15; top_bin >>= 1) {
            log2_width++;
        }

        IFV printf(
                "    histogram shift %d maxval %d ?= %d\n",
                log2_width,
                top_bin,
                int((hi - lo) >> log2_width));

        // The 8-bin variant is used whenever bin indices stay within 0..7.
        if (top_bin <= 7) {
            simd_histogram_8(vals, n, lo, log2_width, bins);
        } else {
            simd_histogram_16(vals, n, lo, log2_width, bins);
        }
        IFV {
            // debug-only consistency check: every value must be accounted for
            int sum = below + above;
            printf("    n_lt=%ld hist=[", below);
            for (int b = 0; b <= top_bin; b++) {
                printf("%d ", bins[b]);
                sum += bins[b];
            }
            printf("] n_gt=%ld sum=%d\n", above, sum);
            assert(sum == n);
        }

        // Walk the bins, accumulating counts, and stop at the first bin where
        // the running total reaches q_min.
        size_t running = below;
        int bin = 0;
        while (bin <= top_bin) {
            running += bins[bin];
            if (running >= q_min) {
                break;
            }
            bin++;
        }
        IFV printf("    i=%d sum_below=%ld\n", bin, running);
        if (bin > top_bin) {
            // q_min was not reached inside the range.
            // NOTE(review): with NDEBUG this assert vanishes and the range is
            // left unchanged for the next round.
            assert(!"not implemented");
        } else {
            // Shrink the search range to the selected bin.
            lo += (bin << log2_width);
            hi = lo + (1 << log2_width) - 1;
            below = running - bins[bin];
            above = n - running;
        }

        IFV printf(
                "    new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n",
                lo,
                hi,
                below,
                above);

        if (hi <= lo) {
            // The bin holds a single value: use it as the threshold and keep
            // just enough of the equal values to reach q_min.
            threshold = lo;
            quantile = q_min;
            ties_total = n - above - below;
            ties_kept = q_min - below;
            IFV printf("    FOUND3\n");
            break;
        }

        // The bin still spans several values; accept one of its edges if the
        // count on that side already lies within [q_min, q_max].
        if (q_min <= below && below <= q_max) {
            IFV printf("    FOUND1\n");
            threshold = lo;
            quantile = below;
            break;
        }

        const size_t up_to_hi = n - above;
        if (q_min <= up_to_hi && up_to_hi <= q_max) {
            threshold = hi + 1;
            quantile = up_to_hi;
            IFV printf("    FOUND2\n");
            break;
        }
    }

    IFV printf(
            "end bissection: thresh=%d q=%ld n_eq=%ld\n",
            threshold,
            quantile,
            ties_kept);

    // Undo the mirroring applied before the loop.
    if (!C::is_max) {
        if (ties_kept != 0) {
            // threshold unchanged; keep the complementary share of the ties
            ties_kept = ties_total - ties_kept;
        } else {
            threshold--;
        }
        quantile = n - quantile;
        IFV printf(
                "revert due to CMin, q->%ld n_eq->%ld\n", quantile, ties_kept);
    }

    // simd_compress_array is defined elsewhere; its return value is expected
    // to match the quantile computed above.
    size_t wp = simd_compress_array<C>(vals, ids, n, threshold, ties_kept);
    IFV printf("wp=%ld ?= %ld\n", wp, quantile);
    assert(wp == quantile);
    if (q_out != nullptr) {
        *q_out = wp;
    }

    return threshold;
}