PirReply PIRServer::generate_reply()

in pir_server.cpp [130:258]


/**
 * Computes the PIR reply for one client query.
 *
 * The database (db_) is processed one dimension per recursion level, one level
 * per entry of pir_params_.nvec. At level i, with n_i = nvec[i]:
 *   1. every ciphertext in query[i] is expanded with expand_query() and the
 *      pieces are concatenated into (at least) n_i ciphertexts;
 *   2. the expanded query and the current plaintexts are put in NTT form;
 *   3. with m = product / n_i, for every k < m the server computes
 *        sum_{j < n_i} expanded_query[j] * cur[k + j * m]
 *      and transforms the result back out of NTT form;
 *   4. on the last level these m ciphertexts are returned; otherwise each one
 *      is decomposed into pir_params_.expansion_ratio plaintexts
 *      (decompose_to_plaintexts_ptr) that become the database of the next level.
 *
 * Side effects: if the database was not preprocessed, its plaintexts are
 * transformed to NTT form in place and is_db_preprocessed_ is set, so a later
 * call does not transform them a second time. Progress is printed to stdout.
 *
 * @param query      per-level query ciphertexts; query[i] must expand to at
 *                   least nvec[i] ciphertexts.
 * @param client_id  forwarded to expand_query() (presumably selects that
 *                   client's Galois keys — confirm in expand_query).
 * @return the ciphertexts produced at the last recursion level.
 * @throws std::invalid_argument if a level has n_i == 0 or its expanded query
 *         holds fewer than n_i ciphertexts (which would be read out of bounds).
 */
PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {

    vector<uint64_t> nvec = pir_params_.nvec;

    // Total number of plaintexts addressed across all dimensions.
    uint64_t product = 1;
    for (uint64_t n : nvec) {
        product *= n;
    }

    // Polynomial degree: the most ciphertexts a single query ciphertext expands into,
    // and the coefficient count of each decomposed plaintext.
    const auto coeff_count = params_.poly_modulus_degree();

    vector<Plaintext> *cur = db_.get();
    vector<Plaintext> intermediate_plain; // decomposed plaintexts feeding the next level

    auto pool = MemoryManager::GetPool();

    // logt = floor(log2(t)) for the plain modulus t, computed exactly with integer
    // shifts. floor(log2(double(t))) can round up to the next integer when t is large
    // and just below a power of two (both the conversion and log2 lose precision).
    int logt = -1;
    for (uint64_t t = params_.plain_modulus().value(); t != 0; t >>= 1) {
        ++logt;
    }

    cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
    for (uint32_t i = 0; i < nvec.size(); i++) {
        cout << "Server: " << i + 1 << "-th recursion level started " << endl; 

        vector<Ciphertext> expanded_query; 

        uint64_t n_i = nvec[i];
        cout << "Server: n_i = " << n_i << endl; 
        cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
        for (uint32_t j = 0; j < query[i].size(); j++){
            // Each query ciphertext expands into coeff_count ciphertexts, except the last,
            // which supplies the remaining n_i % coeff_count. When n_i is an exact
            // multiple of coeff_count that remainder is 0, yet the last ciphertext must
            // still supply a full coeff_count; otherwise expanded_query comes up short.
            uint64_t total = coeff_count;
            if (j == query[i].size() - 1){
                total = n_i % coeff_count;
                if (total == 0) {
                    total = coeff_count;
                }
            }
            cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
            expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
                    std::make_move_iterator(expanded_query_part.end()));
            expanded_query_part.clear(); 
        }
        cout << "Server: expansion done " << endl; 
        if (expanded_query.size() != n_i) {
            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
        }
        // The inner product below reads expanded_query[0 .. n_i-1] and divides by n_i;
        // refuse to continue rather than read out of bounds or divide by zero.
        if (n_i == 0 || expanded_query.size() < n_i) {
            throw std::invalid_argument(
                "generate_reply: expanded query has fewer ciphertexts than n_i");
        }

        // Transform expanded query to NTT so multiply_plain works on NTT-form operands.
        for (auto &ctxt : expanded_query) {
            evaluator_->transform_to_ntt_inplace(ctxt);
        }

        // Transform plaintexts to NTT. If the database is pre-processed, level 0 can skip
        // it; deeper levels always operate on freshly decomposed (non-NTT) plaintexts.
        if ((!is_db_preprocessed_) || i > 0) {
            for (auto &ptxt : *cur) {
                evaluator_->transform_to_ntt_inplace(ptxt, params_.parms_id());
            }
            if (i == 0) {
                // db_ is now in NTT form in place; remember that so a later call does
                // not transform it again.
                is_db_preprocessed_ = true;
            }
        }

        for (uint64_t k = 0; k < product; k++) {
            if ((*cur)[k].is_zero()){
                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl; 
            }
        }

        // After this, product is the number of output ciphertexts of this level.
        product /= n_i;

        vector<Ciphertext> intermediateCtxts(product);
        Ciphertext temp;

        for (uint64_t k = 0; k < product; k++) {

            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);

            for (uint64_t j = 1; j < n_i; j++) {
                evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
            }
        }

        for (auto &ctxt : intermediateCtxts) {
            evaluator_->transform_from_ntt_inplace(ctxt);
        }

        if (i == nvec.size() - 1) {
            return intermediateCtxts;
        } else {
            intermediate_plain.clear();
            intermediate_plain.reserve(pir_params_.expansion_ratio * product);
            cur = &intermediate_plain;

            auto tempplain = util::allocate<Plaintext>(
                pir_params_.expansion_ratio * product,
                pool, coeff_count);

            for (uint64_t rr = 0; rr < product; rr++) {

                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);

                for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
                    auto offset = rr * pir_params_.expansion_ratio + jj;
                    // tempplain is released at the end of this scope; move rather than copy.
                    intermediate_plain.emplace_back(std::move(tempplain[offset]));
                }
            }
            product *= pir_params_.expansion_ratio; // multiply by expansion rate.
        }
        cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
        cout << endl;
    }
    cout << "reply generated!  " << endl;
    // Reached only when nvec is empty: the loop above returns on its last level.
    assert(0);
    vector<Ciphertext> fail(1);
    return fail;
}