in pir_server.cpp [130:258]
/**
 * Answers a PIR query by recursive selection over the dimensions listed in
 * pir_params_.nvec.
 *
 * At recursion level i the ciphertexts in query[i] are expanded (through
 * expand_query) into n_i = nvec[i] ciphertexts. Each of the resulting
 * `product / n_i` output ciphertexts is the sum over j of
 * expanded_query[j] * (*cur)[k + j * (product / n_i)], computed in NTT form.
 * Between levels every output ciphertext is decomposed (decompose_to_plaintexts_ptr)
 * into pir_params_.expansion_ratio plaintexts, which become the input set of
 * the next level.
 *
 * @param query     One vector of query ciphertexts per entry of nvec.
 * @param client_id Forwarded to expand_query (presumably selects that client's
 *                  Galois keys — confirm against expand_query).
 * @return The ciphertexts produced by the last recursion level.
 * @throws std::invalid_argument if nvec is empty, the database is not set,
 *         the query has fewer levels than nvec, the current plaintext set has
 *         fewer entries than the level reads, or a level's query expands into
 *         fewer than n_i ciphertexts.
 *
 * Side effect: if the database was not pre-processed, it is transformed to
 * NTT form in place during the first level, and is_db_preprocessed_ is then
 * set so that later calls do not transform it a second time.
 */
PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
    vector<uint64_t> nvec = pir_params_.nvec;
    if (nvec.empty()) {
        throw std::invalid_argument("generate_reply: pir_params_.nvec is empty");
    }
    if (query.size() < nvec.size()) {
        throw std::invalid_argument("generate_reply: query has fewer levels than pir_params_.nvec");
    }

    // Number of plaintexts the first recursion level selects over.
    uint64_t product = 1;
    for (uint32_t i = 0; i < nvec.size(); i++) {
        product *= nvec[i];
    }
    auto coeff_count = params_.poly_modulus_degree();
    vector<Plaintext> *cur = db_.get();
    if (cur == nullptr) {
        throw std::invalid_argument("generate_reply: database has not been set");
    }
    vector<Plaintext> intermediate_plain; // decomposition of the previous level's ciphertexts
    auto pool = MemoryManager::GetPool();
    int N = params_.poly_modulus_degree();
    int logt = floor(log2(params_.plain_modulus().value()));
    cout << "expansion ratio = " << pir_params_.expansion_ratio << endl;

    for (uint32_t i = 0; i < nvec.size(); i++) {
        cout << "Server: " << i + 1 << "-th recursion level started " << endl;
        uint64_t n_i = nvec[i];
        cout << "Server: n_i = " << n_i << endl;

        // This level reads entries [0, product) of *cur; fail loudly instead
        // of indexing past the end.
        if (cur->size() < product) {
            throw std::invalid_argument(
                "generate_reply: level " + std::to_string(i + 1) + " needs " +
                std::to_string(product) + " plaintexts but only " +
                std::to_string(cur->size()) + " are available");
        }

        vector<Ciphertext> expanded_query;
        expanded_query.reserve(n_i);
        cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
        for (uint32_t j = 0; j < query[i].size(); j++) {
            // Every query ciphertext expands into N ciphertexts except the
            // last, which carries only the remainder n_i mod N. If n_i is an
            // exact multiple of N the remainder is 0 and the last ciphertext
            // is full, so it must still expand into N (not 0).
            uint64_t total = N;
            if (j == query[i].size() - 1 && n_i % N != 0) {
                total = n_i % N;
            }
            cout << "-- expanding one query ctxt into " << total << " ctxts " << endl;
            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
            expanded_query.insert(expanded_query.end(),
                                  std::make_move_iterator(expanded_query_part.begin()),
                                  std::make_move_iterator(expanded_query_part.end()));
        }
        cout << "Server: expansion done " << endl;
        if (expanded_query.size() != n_i) {
            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl;
            // The selection loop below reads expanded_query[0 .. n_i-1];
            // fewer ciphertexts would be an out-of-bounds read.
            if (expanded_query.size() < n_i) {
                throw std::invalid_argument(
                    "generate_reply: query level " + std::to_string(i + 1) +
                    " expanded into " + std::to_string(expanded_query.size()) +
                    " ciphertexts, expected " + std::to_string(n_i));
            }
        }

        // Transform expanded query to NTT so it can be multiplied with the
        // NTT-form plaintexts.
        for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
            evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
        }
        // Transform plaintexts to NTT. A pre-processed database is already in
        // NTT form and can be skipped at level 0.
        if ((!is_db_preprocessed_) || i > 0) {
            for (uint32_t jj = 0; jj < cur->size(); jj++) {
                evaluator_->transform_to_ntt_inplace((*cur)[jj], params_.parms_id());
            }
            if (i == 0) {
                // The database itself was just transformed in place; record it
                // so the next call does not apply the NTT a second time.
                is_db_preprocessed_ = true;
            }
        }
        for (uint64_t k = 0; k < product; k++) {
            if ((*cur)[k].is_zero()) {
                cout << k + 1 << "/ " << product << "-th ptxt = 0 " << endl;
            }
        }

        product /= n_i;
        vector<Ciphertext> intermediateCtxts(product);
        Ciphertext temp;
        for (uint64_t k = 0; k < product; k++) {
            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
            for (uint64_t j = 1; j < n_i; j++) {
                evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
            }
        }
        for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
            evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
        }

        if (i == nvec.size() - 1) {
            return intermediateCtxts;
        } else {
            // Decompose each ciphertext into expansion_ratio plaintexts that
            // the next level selects over.
            intermediate_plain.clear();
            intermediate_plain.reserve(pir_params_.expansion_ratio * product);
            cur = &intermediate_plain;
            auto tempplain = util::allocate<Plaintext>(
                pir_params_.expansion_ratio * product,
                pool, coeff_count);
            for (uint64_t rr = 0; rr < product; rr++) {
                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);
                for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
                    auto offset = rr * pir_params_.expansion_ratio + jj;
                    // tempplain is discarded at the end of this scope, so move
                    // rather than copy each plaintext.
                    intermediate_plain.emplace_back(std::move(tempplain[offset]));
                }
            }
            product *= pir_params_.expansion_ratio; // multiply by expansion rate.
        }
        cout << "Server: " << i + 1 << "-th recursion level finished " << endl;
        cout << endl;
    }
    cout << "reply generated! " << endl;
    // Unreachable: nvec is non-empty, so the last level always returns above.
    assert(0);
    vector<Ciphertext> fail(1);
    return fail;
}