in faiss/IndexIVFPQFastScan.cpp [953:1146]
/**
 * Batched fast-scan search over the inverted lists (impl 12 and 13).
 *
 * Steps, in order:
 *  1. coarse-quantize the n queries to nprobe lists each;
 *  2. build the uint8 look-up tables (and optional biases / normalizers)
 *     with compute_LUT_uint8;
 *  3. collect all valid (query, list, rank) triplets and sort them by list
 *     number so that one inverted list is scanned for several queries at
 *     once (at most qbs2 queries per scan);
 *  4. for each such batch, pack the per-query LUTs and run
 *     pq4_accumulate_loop_qbs into the result handler;
 *  5. flush the handler into distances / labels with to_flat_arrays.
 *
 * @param n          number of query vectors; when 0 the function returns
 *                   immediately and writes nothing (not even the counters)
 * @param x          query vectors, forwarded to the coarse quantizer and to
 *                   compute_LUT_uint8
 * @param k          number of results per query; k == 1 selects
 *                   SingleResultHandler regardless of impl
 * @param distances  output distances, filled by to_flat_arrays
 * @param labels     output labels; also handed to HeapHandler as its label
 *                   storage when impl == 12
 * @param impl       12 selects HeapHandler, 13 selects ReservoirHandler
 *                   (only consulted when k != 1)
 * @param ndis_out   receives the sum over scanned batches of
 *                   (queries in batch) * (size of the list)
 * @param nlist_out  receives `nlist` (see the note at the end of the body)
 *
 * Throws (via FAISS_THROW_IF_NOT) when bbs != 32, or when k != 1 and impl is
 * neither 12 nor 13.
 */
void IndexIVFPQFastScan::search_implem_12(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        int impl,
        size_t* ndis_out,
        size_t* nlist_out) const {
    if (n == 0) { // does not work well with reservoir
        return;
    }
    FAISS_THROW_IF_NOT(bbs == 32);
    // No result handler is created for any other combination; fail here
    // instead of dereferencing a null handler further down.
    FAISS_THROW_IF_NOT(k == 1 || impl == 12 || impl == 13);

    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

    // Cycle counter samples taken at each TIC; consecutive differences are
    // added to IVFFastScan_stats.times at the end.
    uint64_t times[10];
    memset(times, 0, sizeof(times));
    int ti = 0;
#define TIC times[ti++] = get_cy()
    TIC;

    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());

    TIC;

    // Size of one packed look-up table (one per LUT entry).
    const size_t dim12 = pq.ksub * M2;
    AlignedTable<uint8_t> dis_tables;
    AlignedTable<uint16_t> biases;
    std::unique_ptr<float[]> normalizers(new float[2 * n]);

    compute_LUT_uint8(
            n,
            x,
            coarse_ids.get(),
            coarse_dis.get(),
            dis_tables,
            biases,
            normalizers.get());

    TIC;

    struct QC {
        int qno;     // sequence number of the query
        int list_no; // list to visit
        int rank;    // this is the rank'th result of the coarse quantizer
    };
    // When true, the LUT is selected by query number only; otherwise by the
    // flat (query, rank) index.
    const bool single_LUT = !(by_residual && metric_type == METRIC_L2);

    std::vector<QC> qcs;
    {
        qcs.reserve(size_t(n) * nprobe);
        // size_t: the flat index runs up to n * nprobe, which may not fit in
        // an int.
        size_t ij = 0;
        for (idx_t i = 0; i < n; i++) {
            for (size_t j = 0; j < nprobe; j++, ij++) {
                // negative ids are skipped (no list to visit)
                if (coarse_ids[ij] >= 0) {
                    qcs.push_back(QC{int(i), int(coarse_ids[ij]), int(j)});
                }
            }
        }
        // group the triplets that visit the same inverted list
        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
            return a.list_no < b.list_no;
        });
    }
    TIC;

    // prepare the result handlers

    std::unique_ptr<SIMDResultHandler<C, true>> handler;
    AlignedTable<uint16_t> tmp_distances;

    using HeapHC = HeapHandler<C, true>;
    using ReservoirHC = ReservoirHandler<C, true>;
    using SingleResultHC = SingleResultHandler<C, true>;

    if (k == 1) {
        handler.reset(new SingleResultHC(n, 0));
    } else if (impl == 12) {
        tmp_distances.resize(n * k);
        handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
    } else if (impl == 13) {
        handler.reset(new ReservoirHC(n, 0, k, 2 * k));
    }

    // Maximum number of queries scanned together against one list; falls
    // back to 11 when the member qbs2 is 0.
    const int qbs2 = this->qbs2 ? this->qbs2 : 11;

    // Per-batch biases, re-ordered to match the batch's query order.
    const bool has_biases = biases.get() != nullptr;
    std::vector<uint16_t> tmp_bias;
    if (has_biases) {
        tmp_bias.resize(qbs2);
        handler->dbias = tmp_bias.data();
    }
    TIC;

    size_t ndis = 0;

    size_t i0 = 0;
    uint64_t t_copy_pack = 0, t_scan = 0;
    while (i0 < qcs.size()) {
        uint64_t tt0 = get_cy();

        // find all queries that access this inverted list
        // (the batch [i0, i1) is capped at qbs2 entries)
        const int list_no = qcs[i0].list_no;
        size_t i1 = i0 + 1;

        while (i1 < qcs.size() && i1 < i0 + qbs2) {
            if (qcs[i1].list_no != list_no) {
                break;
            }
            i1++;
        }

        const size_t list_size = invlists->list_size(list_no);

        if (list_size == 0) {
            i0 = i1;
            continue;
        }

        // re-organize LUTs and biases into the right order
        const int nc = int(i1 - i0);

        std::vector<int> q_map(nc), lut_entries(nc);
        AlignedTable<uint8_t> LUT(nc * dim12);
        memset(LUT.get(), -1, nc * dim12);
        const int qbs = pq4_preferred_qbs(nc);

        for (size_t i = i0; i < i1; i++) {
            const QC& qc = qcs[i];
            const size_t slot = i - i0;
            q_map[slot] = qc.qno;
            // flat index into the n * nprobe coarse arrays; kept in size_t
            // so that indexing `biases` cannot overflow an int
            const size_t ij = size_t(qc.qno) * nprobe + qc.rank;
            lut_entries[slot] = single_LUT ? qc.qno : int(ij);
            if (has_biases) {
                tmp_bias[slot] = biases[ij];
            }
        }
        pq4_pack_LUT_qbs_q_map(
                qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());

        // access the inverted list

        ndis += (i1 - i0) * list_size;

        InvertedLists::ScopedCodes codes(invlists, list_no);
        InvertedLists::ScopedIds ids(invlists, list_no);

        // prepare the handler

        handler->ntotal = list_size;
        handler->q_map = q_map.data();
        handler->id_map = ids.get();
        uint64_t tt1 = get_cy();

        // Recover the concrete handler type so that the accumulate loop is
        // called with the derived class rather than the base class.
#define DISPATCH(classHC)                                          \
    if (dynamic_cast<classHC*>(handler.get())) {                   \
        auto* res = static_cast<classHC*>(handler.get());          \
        pq4_accumulate_loop_qbs(                                   \
                qbs, list_size, M2, codes.get(), LUT.get(), *res); \
    }
        DISPATCH(HeapHC)
        else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)

        // prepare for next loop
        i0 = i1;

        uint64_t tt2 = get_cy();
        t_copy_pack += tt1 - tt0;
        t_scan += tt2 - tt1;
    }
    TIC;

    // labels is in-place for HeapHC
    // bit 16 of `skip` replaces the normalizers by nullptr
    handler->to_flat_arrays(
            distances, labels, (skip & 16) ? nullptr : normalizers.get());

    TIC;

    // these stats are not thread-safe

    for (int i = 1; i < ti; i++) {
        IVFFastScan_stats.times[i] += times[i] - times[i - 1];
    }
    IVFFastScan_stats.t_copy_pack += t_copy_pack;
    IVFFastScan_stats.t_scan += t_scan;

    if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
        for (int i = 0; i < 4; i++) {
            IVFFastScan_stats.reservoir_times[i] += rh->times[i];
        }
    }

    *ndis_out = ndis;
    // NOTE(review): no local named `nlist` is declared in this function, so
    // this presumably reports the index's total list count rather than the
    // number of lists actually visited -- confirm that this is intended.
    *nlist_out = nlist;
}