in faiss/utils/partitioning.cpp [515:685]
/* Fuzzy partitioning of n uint16 values (and their ids) by iterated
 * histograms.
 *
 * Searches for a threshold such that the number q of entries kept by
 * simd_compress_array<C> satisfies q_min <= q <= q_max. Each iteration
 * builds a 16- or 8-bin histogram of the current inclusive value range
 * [s0, s1] and narrows the range to the bin where the cumulative count
 * reaches q_min. It stops as soon as a bin boundary gives an admissible
 * count, or when the range collapses to a single value; in that case the
 * (possibly mirrored) lower end of the range is hit exactly by keeping
 * n_eq of the values equal to the threshold.
 *
 * For CMax the count is taken on values below the threshold. For CMin the
 * problem is mirrored (a count q of kept values <-> n - q values below the
 * threshold) and the threshold / counts are converted back at the end.
 *
 * @param vals   n values; the search assumes (precondition) that every value
 *               lies in [s0i, s1i]. Passed to simd_compress_array<C>, which
 *               presumably compacts the kept entries in place.
 * @param ids    ids associated with vals (assumed to be rearranged alongside)
 * @param n      number of entries
 * @param q_min  minimum acceptable number of kept entries
 * @param q_max  maximum acceptable number of kept entries
 * @param q_out  if non-null, receives the number of kept entries
 * @param s0i    lower bound on the values (inclusive)
 * @param s1i    upper bound on the values (inclusive)
 * @return       the threshold. The early exits below return without calling
 *               simd_compress_array: 0 if q_min == 0 (*q_out = 0), 0xffff if
 *               q_max >= n (*q_out = q_max), s0i if s0i == s1i
 *               (*q_out = q_min).
 */
uint16_t simd_partition_fuzzy_with_bounds_histogram(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out,
        uint16_t s0i,
        uint16_t s1i) {
    if (q_min == 0) {
        if (q_out) {
            *q_out = 0;
        }
        return 0;
    }
    if (q_max >= n) {
        if (q_out) {
            *q_out = q_max;
        }
        return 0xffff;
    }
    if (s0i == s1i) {
        if (q_out) {
            *q_out = q_min;
        }
        return s0i;
    }

    IFV printf(
            "partition fuzzy, q=%zu:%zu / %zu, bounds=%d %d\n",
            q_min,
            q_max,
            n,
            s0i,
            s1i);

    if (!C::is_max) {
        // Keeping q in [q_min, q_max] of the largest values is the same as
        // having n - q values below the threshold, with n - q in
        // [n - q_max, n - q_min]. Mirroring reverses the order of the bounds,
        // so they must be swapped. Otherwise the acceptance tests below, which
        // require q_min <= x <= q_max, can only succeed when q_min == q_max,
        // and the fuzzy range is ignored.
        // Since q_min <= q_max < n here, the mirrored q_min is still >= 1.
        size_t mirrored_q_min = n - q_max;
        q_max = n - q_min;
        q_min = mirrored_q_min;
        IFV printf(
                "revert due to CMin, q_min:q_max -> %zu:%zu\n", q_min, q_max);
    }

    // lower and upper bound of the current value range, inclusive
    int s0 = s0i, s1 = s1i;
    // number of values < s0 and > s1 (zero at start: all values are assumed
    // to lie in [s0i, s1i])
    size_t n_lt = 0, n_gt = 0;

    // outputs of the loop (initialized so they are never read uninitialized,
    // even if the loop were to end without reaching one of the FOUND exits)
    int thresh = 0;      // final threshold
    uint64_t tot_eq = 0; // total nb of values equal to thresh
    uint64_t n_eq = 0;   // nb of values equal to thresh to keep
    size_t q = 0;        // final quantile

    // buffer for the histograms
    int hist[16];

    for (int it = 0; it < 20; it++) {
        // s0 < s1 here: equal bounds are handled before the loop
        // (s0i == s1i) and by the FOUND3 exit inside it.
        IFV printf(
                " it %d bounds: %d %d n_lt=%zu n_gt=%zu\n",
                it,
                s0,
                s1,
                n_lt,
                n_gt);

        // smallest shift such that the range fits in at most 16 bins;
        // maxval is the index of the last bin
        int shift = 0;
        int maxval = s1 - s0;
        while (maxval > 15) {
            shift++;
            maxval >>= 1;
        }
        IFV printf(
                " histogram shift %d maxval %d ?= %d\n",
                shift,
                maxval,
                int((s1 - s0) >> shift));

        if (maxval > 7) {
            simd_histogram_16(vals, n, s0, shift, hist);
        } else {
            simd_histogram_8(vals, n, s0, shift, hist);
        }

        IFV {
            size_t sum = n_lt + n_gt;
            printf(" n_lt=%zu hist=[", n_lt);
            for (int i = 0; i <= maxval; i++) {
                printf("%d ", hist[i]);
                sum += hist[i];
            }
            printf("] n_gt=%zu sum=%zu\n", n_gt, sum);
            assert(sum == n);
        }

        // first bin where the cumulative count reaches q_min
        size_t sum_below = n_lt;
        int i;
        for (i = 0; i <= maxval; i++) {
            sum_below += hist[i];
            if (sum_below >= q_min) {
                break;
            }
        }
        IFV printf(" i=%d sum_below=%zu\n", i, sum_below);

        if (i <= maxval) {
            // narrow the search to bin i
            s0 = s0 + (i << shift);
            s1 = s0 + (1 << shift) - 1;
            n_lt = sum_below - hist[i];
            n_gt = n - sum_below;
        } else {
            assert(!"not implemented");
        }
        IFV printf(
                " new bin: s0=%d s1=%d n_lt=%zu n_gt=%zu\n",
                s0,
                s1,
                n_lt,
                n_gt);

        if (s1 > s0) {
            // The bin search leaves n_lt < q_min, so this exit is not
            // expected to fire; kept as a defensive check.
            if (n_lt >= q_min && q_max >= n_lt) {
                IFV printf(" FOUND1\n");
                thresh = s0;
                q = n_lt;
                break;
            }
            // all values of the bin are below s1 + 1
            size_t n_lt_2 = n - n_gt;
            if (n_lt_2 >= q_min && q_max >= n_lt_2) {
                thresh = s1 + 1;
                q = n_lt_2;
                IFV printf(" FOUND2\n");
                break;
            }
        } else {
            // single-value bin: take exactly q_min, completing with
            // values equal to the threshold
            thresh = s0;
            q = q_min;
            tot_eq = n - n_gt - n_lt;
            n_eq = q_min - n_lt;
            IFV printf(" FOUND3\n");
            break;
        }
    }

    IFV printf(
            "end bissection: thresh=%d q=%zu n_eq=%llu\n",
            thresh,
            q,
            static_cast<unsigned long long>(n_eq));

    if (!C::is_max) {
        // convert back from the mirrored problem
        if (n_eq == 0) {
            thresh--;
        } else {
            // thresh unchanged, keep the complementary equal values
            n_eq = tot_eq - n_eq;
        }
        q = n - q;
        IFV printf(
                "revert due to CMin, q->%zu n_eq->%llu\n",
                q,
                static_cast<unsigned long long>(n_eq));
    }

    size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq);

    IFV printf("wp=%zu ?= %zu\n", wp, q);
    assert(wp == q);
    if (q_out) {
        *q_out = wp;
    }
    return thresh;
}