in native/src/seal/evaluator.cpp [2063:2307]
void Evaluator::switch_key_inplace(
Ciphertext &encrypted, ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, size_t kswitch_keys_index,
MemoryPoolHandle pool) const
{
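// Key switching outline: take the RNS decomposition of target_iter, compute its inner product
// with the selected key vector over the extended basis (the decomposition moduli plus the
// special key modulus), then divide by the special modulus with rounding (modulus switching)
// and add the result into the components of encrypted.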
auto parms_id = encrypted.parms_id();
auto &context_data = *context_.get_context_data(parms_id);
auto &parms = context_data.parms();
auto &key_context_data = *context_.key_context_data();
auto &key_parms = key_context_data.parms();
auto scheme = parms.scheme();
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
{
throw invalid_argument("encrypted is not valid for encryption parameters");
}
if (!target_iter)
{
throw invalid_argument("target_iter");
}
if (!context_.using_keyswitching())
{
throw logic_error("keyswitching is not supported by the context");
}
// Don't validate all of kswitch_keys but just check the parms_id.
if (kswitch_keys.parms_id() != context_.key_parms_id())
{
throw invalid_argument("parameter mismatch");
}
if (kswitch_keys_index >= kswitch_keys.data().size())
{
throw out_of_range("kswitch_keys_index");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
if (scheme == scheme_type::bfv && encrypted.is_ntt_form())
{
throw invalid_argument("BFV encrypted cannot be in NTT form");
}
if (scheme == scheme_type::ckks && !encrypted.is_ntt_form())
{
throw invalid_argument("CKKS encrypted must be in NTT form");
}
// Extract encryption parameters.
size_t coeff_count = parms.poly_modulus_degree();
size_t decomp_modulus_size = parms.coeff_modulus().size();
auto &key_modulus = key_parms.coeff_modulus();
size_t key_modulus_size = key_modulus.size();
size_t rns_modulus_size = decomp_modulus_size + 1;
auto key_ntt_tables = iter(key_context_data.small_ntt_tables());
auto modswitch_factors = key_context_data.rns_tool()->inv_q_last_mod_q();
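// key_ntt_tables covers the full key-level modulus chain; modswitch_factors holds the
// inverse of the last (special) key modulus modulo each remaining modulus q_i, used in the
// final scaling step below.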
// Size check
if (!product_fits_in(coeff_count, rns_modulus_size, size_t(2)))
{
throw logic_error("invalid parameters");
}
// Prepare input
auto &key_vector = kswitch_keys.data()[kswitch_keys_index];
size_t key_component_count = key_vector[0].data().size();
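// Each entry of key_vector is a key ciphertext; with standard relinearization/Galois keys it
// has two polynomial components, so key_component_count is normally 2.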
// Check only the used component in KSwitchKeys.
for (auto &each_key : key_vector)
{
if (!is_metadata_valid_for(each_key, context_) || !is_buffer_valid(each_key))
{
throw invalid_argument("kswitch_keys is not valid for encryption parameters");
}
}
// Create a copy of target_iter
SEAL_ALLOCATE_GET_RNS_ITER(t_target, coeff_count, decomp_modulus_size, pool);
set_uint(target_iter, decomp_modulus_size * coeff_count, t_target);
// In CKKS t_target is in NTT form; switch back to normal form
if (scheme == scheme_type::ckks)
{
inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables);
}
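// The loop below reduces each component of t_target modulo a different key modulus and then
// applies the NTT with respect to that modulus, which requires coefficient (non-NTT) input.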
// Temporary result
auto t_poly_prod(allocate_zero_poly_array(key_component_count, coeff_count, rns_modulus_size, pool));
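// t_poly_prod layout: key_component_count polynomials, each with rns_modulus_size RNS
// components of coeff_count coefficients; the extra component (index decomp_modulus_size)
// corresponds to the special key modulus.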
SEAL_ITERATE(iter(size_t(0)), rns_modulus_size, [&](auto I) {
size_t key_index = (I == decomp_modulus_size ? key_modulus_size - 1 : I);
// The product of two numbers is at most 60 + 60 = 120 bits, so up to 256 such products can be summed into 128 bits without reduction.
size_t lazy_reduction_summand_bound = size_t(SEAL_MULTIPLY_ACCUMULATE_USER_MOD_MAX);
size_t lazy_reduction_counter = lazy_reduction_summand_bound;
// Allocate memory for a lazy accumulator (128-bit coefficients)
auto t_poly_lazy(allocate_zero_poly_array(key_component_count, coeff_count, 2, pool));
// Semantic misuse of PolyIter; this is really pointing to the data for a single RNS factor
PolyIter accumulator_iter(t_poly_lazy.get(), 2, coeff_count);
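// Each "coefficient" of the accumulator is a (low, high) pair of 64-bit words holding one
// 128-bit lazy sum.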
// Multiply with keys and perform lazy reduction on product's coefficients
SEAL_ITERATE(iter(size_t(0)), decomp_modulus_size, [&](auto J) {
SEAL_ALLOCATE_GET_COEFF_ITER(t_ntt, coeff_count, pool);
ConstCoeffIter t_operand;
// RNS-NTT form exists in input
if ((scheme == scheme_type::ckks) && (I == J))
{
t_operand = target_iter[J];
}
// Perform RNS-NTT conversion
else
{
// No need to perform RNS conversion (modular reduction)
if (key_modulus[J] <= key_modulus[key_index])
{
set_uint(t_target[J], coeff_count, t_ntt);
}
// Perform RNS conversion (modular reduction)
else
{
modulo_poly_coeffs(t_target[J], coeff_count, key_modulus[key_index], t_ntt);
}
// Lazy NTT conversion; outputs are in [0, 4q)
ntt_negacyclic_harvey_lazy(t_ntt, key_ntt_tables[key_index]);
t_operand = t_ntt;
}
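// At this point t_operand holds the J-th decomposition component in NTT form with respect to
// key_modulus[key_index], possibly only lazily reduced (values in [0, 4*key_modulus[key_index])).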
// Multiply with keys and modularly accumulate the products in a lazy fashion
SEAL_ITERATE(iter(key_vector[J].data(), accumulator_iter), key_component_count, [&](auto K) {
if (!lazy_reduction_counter)
{
SEAL_ITERATE(iter(t_operand, get<0>(K)[key_index], get<1>(K)), coeff_count, [&](auto L) {
unsigned long long qword[2]{ 0, 0 };
multiply_uint64(get<0>(L), get<1>(L), qword);
// Accumulate the product of t_operand and the key component into t_poly_lazy and reduce
add_uint128(qword, get<2>(L).ptr(), qword);
get<2>(L)[0] = barrett_reduce_128(qword, key_modulus[key_index]);
get<2>(L)[1] = 0;
});
}
else
{
// Same as above but no reduction
SEAL_ITERATE(iter(t_operand, get<0>(K)[key_index], get<1>(K)), coeff_count, [&](auto L) {
unsigned long long qword[2]{ 0, 0 };
multiply_uint64(get<0>(L), get<1>(L), qword);
add_uint128(qword, get<2>(L).ptr(), qword);
get<2>(L)[0] = qword[0];
get<2>(L)[1] = qword[1];
});
}
});
if (!--lazy_reduction_counter)
{
lazy_reduction_counter = lazy_reduction_summand_bound;
}
});
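// t_poly_lazy now holds, for each key component, the inner product over all J for RNS
// component key_index, possibly still in unreduced 128-bit form.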
// PolyIter pointing to the destination t_poly_prod, shifted to the appropriate modulus
PolyIter t_poly_prod_iter(t_poly_prod.get() + (I * coeff_count), coeff_count, rns_modulus_size);
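// Offsetting by I * coeff_count while keeping the per-polynomial stride of
// rns_modulus_size * coeff_count makes *t_poly_prod_iter[k] point at RNS component I of
// output polynomial k.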
// Final modular reduction
SEAL_ITERATE(iter(accumulator_iter, t_poly_prod_iter), key_component_count, [&](auto K) {
if (lazy_reduction_counter == lazy_reduction_summand_bound)
{
SEAL_ITERATE(iter(get<0>(K), *get<1>(K)), coeff_count, [&](auto L) {
get<1>(L) = static_cast<uint64_t>(*get<0>(L));
});
}
else
{
// Same as above, except a reduction is still needed
SEAL_ITERATE(iter(get<0>(K), *get<1>(K)), coeff_count, [&](auto L) {
get<1>(L) = barrett_reduce_128(get<0>(L).ptr(), key_modulus[key_index]);
});
}
});
});
// Accumulated products are now stored in t_poly_prod
// Perform modulus switching with scaling
PolyIter t_poly_prod_iter(t_poly_prod.get(), coeff_count, rns_modulus_size);
SEAL_ITERATE(iter(encrypted, t_poly_prod_iter), key_component_count, [&](auto I) {
// Lazy reduction; this still needs to be reduced mod qi
CoeffIter t_last(get<1>(I)[decomp_modulus_size]);
inverse_ntt_negacyclic_harvey_lazy(t_last, key_ntt_tables[key_modulus_size - 1]);
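// t_last is the component modulo the last (special) key modulus; the lazy inverse NTT leaves
// its coefficients in [0, 2*modulus).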
// Add (qk-1)/2 to change from flooring to rounding.
uint64_t qk = key_modulus[key_modulus_size - 1].value();
uint64_t qk_half = qk >> 1;
SEAL_ITERATE(t_last, coeff_count, [&](auto &J) {
J = barrett_reduce_64(J + qk_half, key_modulus[key_modulus_size - 1]);
});
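// t_last now holds (ct + qk_half) mod qk; the added shift is compensated per qi below via
// the fix term.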
SEAL_ITERATE(iter(I, key_modulus, key_ntt_tables, modswitch_factors), decomp_modulus_size, [&](auto J) {
SEAL_ALLOCATE_GET_COEFF_ITER(t_ntt, coeff_count, pool);
// (ct mod 4qk) mod qi
uint64_t qi = get<1>(J).value();
if (qk > qi)
{
// This reduction cannot be skipped: the NTT only tolerates input that is less than 4*modulus (i.e. qk <= 4*qi).
modulo_poly_coeffs(t_last, coeff_count, get<1>(J), t_ntt);
}
else
{
set_uint(t_last, coeff_count, t_ntt);
}
// Lazy subtraction, results in [0, 2*qi), since fix is in [0, qi].
uint64_t fix = qi - barrett_reduce_64(qk_half, get<1>(J));
SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; });
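// Adding fix is congruent to subtracting qk_half modulo qi, compensating for the rounding
// shift that was added to t_last above.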
uint64_t qi_lazy = qi << 1; // a multiple of qi large enough to keep the lazy subtraction below nonnegative
if (scheme == scheme_type::ckks)
{
// This ntt_negacyclic_harvey_lazy results in [0, 4*qi).
ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J));
#if SEAL_USER_MOD_BIT_COUNT_MAX > 60
// Reduce from [0, 4qi) to [0, 2qi)
SEAL_ITERATE(t_ntt, coeff_count, [&](auto &K) { K -= SEAL_COND_SELECT(K >= qi_lazy, qi_lazy, 0); });
#else
// Since SEAL uses at most 60-bit moduli, 8*qi < 2^63.
qi_lazy = qi << 2;
#endif
}
else if (scheme == scheme_type::bfv)
{
inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
}
// ((ct mod qi) - (ct mod qk)) mod qi
SEAL_ITERATE(iter(get<0, 1>(J), t_ntt), coeff_count, [&](auto K) { get<0>(K) += qi_lazy - get<1>(K); });
// qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi
multiply_poly_scalar_coeffmod(get<0, 1>(J), coeff_count, get<3>(J), get<1>(J), get<0, 1>(J));
add_poly_coeffmod(get<0, 1>(J), get<0, 0>(J), coeff_count, get<1>(J), get<0, 0>(J));
});
});
}
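// For orientation, a rough, hypothetical sketch of a caller (simplified; not the actual SEAL
// call sites): relinearization passes the ciphertext component to be eliminated as target_iter,
// e.g. for a size-3 ciphertext roughly
//
//     switch_key_inplace(
//         encrypted, RNSIter(encrypted.data(2), coeff_count),
//         static_cast<const KSwitchKeys &>(relin_keys), RelinKeys::get_index(2), pool);
//
// after which the ciphertext is resized back down to two components.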