in native/src/seal/util/rns.cpp [502:694]
// Initializes this RNSTool for the given ring degree, coefficient base q and plain modulus t.
//
// Builds the auxiliary bases B, Bsk = B U {m_sk}, Bsk U {m_tilde} and (only when t is non-zero)
// {t, gamma}, creates NTT tables for Bsk, sets up the base converters between these bases, and
// precomputes the modular products and inverses stored in the members assigned below.
//
// poly_modulus_degree: ring degree; must be a power of two within
//     [SEAL_POLY_MOD_DEGREE_MIN, SEAL_POLY_MOD_DEGREE_MAX].
// q: coefficient modulus base; its size must be within
//     [SEAL_COEFF_MOD_COUNT_MIN, SEAL_COEFF_MOD_COUNT_MAX].
// t: plain modulus; a zero t skips everything related to the {t, gamma} base.
//
// Throws invalid_argument if q or poly_modulus_degree fails the checks above.
// Throws logic_error("invalid parameters") if the size check on coeff_count_ * |Bsk U m_tilde| fails.
// Throws logic_error("invalid rns bases") if NTT table creation throws logic_error or if any of the
// required modular inverses does not exist.
void RNSTool::initialize(size_t poly_modulus_degree, const RNSBase &q, const Modulus &t)
{
// Reject q (by throwing) if its size is out of bounds
if (q.size() < SEAL_COEFF_MOD_COUNT_MIN || q.size() > SEAL_COEFF_MOD_COUNT_MAX)
{
throw invalid_argument("rnsbase is invalid");
}
// Reject poly_modulus_degree (by throwing) if it is not a power of two or is out of bounds
int coeff_count_power = get_power_of_two(poly_modulus_degree);
if (coeff_count_power < 0 || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX ||
poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN)
{
throw invalid_argument("poly_modulus_degree is invalid");
}
t_ = t;
coeff_count_ = poly_modulus_degree;
// Allocate memory for the bases q, B, Bsk, Bsk U m_tilde, t_gamma
size_t base_q_size = q.size();
// In some cases we might need to increase the size of the base B by one, namely we require
// K * n * t * q^2 < q * prod(B) * m_sk, where K takes into account cross terms when larger size ciphertexts
// are used, and n is the "delta factor" for the ring. We reserve 32 bits for K * n. Here the coeff modulus
// primes q_i are bounded to be SEAL_USER_MOD_BIT_COUNT_MAX (60) bits, and all primes in B and m_sk are
// SEAL_INTERNAL_MOD_BIT_COUNT (61) bits.
// Bit length of prod(q), used as the bit count of q in the inequality above
int total_coeff_bit_count = get_significant_bit_count_uint(q.base_prod(), q.size());
// Start with |B| = |q|. The right-hand side below is the bit budget of |q| primes in B plus m_sk;
// if the left-hand side (32 bits for K * n, plus bits of t, plus bits of q) reaches it, B gets one more prime.
size_t base_B_size = base_q_size;
if (32 + t_.bit_count() + total_coeff_bit_count >=
SEAL_INTERNAL_MOD_BIT_COUNT * safe_cast<int>(base_q_size) + SEAL_INTERNAL_MOD_BIT_COUNT)
{
base_B_size++;
}
// |Bsk| = |B| + 1 (adds m_sk) and |Bsk U m_tilde| = |B| + 2
size_t base_Bsk_size = add_safe(base_B_size, size_t(1));
size_t base_Bsk_m_tilde_size = add_safe(base_Bsk_size, size_t(1));
// Stays 0 unless the {t, gamma} base is created below
size_t base_t_gamma_size = 0;
// Size check
if (!product_fits_in(coeff_count_, base_Bsk_m_tilde_size))
{
throw logic_error("invalid parameters");
}
// Sample primes for B and two more primes: m_sk and gamma
// The request is for |B| + 2 primes: the first becomes m_sk, the second gamma, and the next |B| form B
auto baseconv_primes = get_primes(coeff_count_, SEAL_INTERNAL_MOD_BIT_COUNT, base_Bsk_m_tilde_size);
auto baseconv_primes_iter = baseconv_primes.cbegin();
m_sk_ = *baseconv_primes_iter++;
gamma_ = *baseconv_primes_iter++;
vector<Modulus> base_B_primes;
copy_n(baseconv_primes_iter, base_B_size, back_inserter(base_B_primes));
// Set m_tilde_ to a non-prime value (2^32)
m_tilde_ = uint64_t(1) << 32;
// Populate the base arrays
// Bsk is B extended by m_sk; Bsk U m_tilde is Bsk extended by m_tilde
base_q_ = allocate<RNSBase>(pool_, q, pool_);
base_B_ = allocate<RNSBase>(pool_, base_B_primes, pool_);
base_Bsk_ = allocate<RNSBase>(pool_, base_B_->extend(m_sk_));
base_Bsk_m_tilde_ = allocate<RNSBase>(pool_, base_Bsk_->extend(m_tilde_));
// Set up t-gamma base if t_ is non-zero (using BFV)
if (!t_.is_zero())
{
base_t_gamma_size = 2;
base_t_gamma_ = allocate<RNSBase>(pool_, vector<Modulus>{ t_, gamma_ }, pool_);
}
// Generate the Bsk NTTTables; these are used for NTT after base extension to Bsk
try
{
CreateNTTTables(
coeff_count_power, vector<Modulus>(base_Bsk_->base(), base_Bsk_->base() + base_Bsk_size),
base_Bsk_ntt_tables_, pool_);
}
catch (const logic_error &)
{
// Re-throw with the same message used for every other base-related failure in this function
throw logic_error("invalid rns bases");
}
// Set up BaseConverter for q --> Bsk
base_q_to_Bsk_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_Bsk_, pool_);
// Set up BaseConverter for q --> {m_tilde}
base_q_to_m_tilde_conv_ = allocate<BaseConverter>(pool_, *base_q_, RNSBase({ m_tilde_ }, pool_), pool_);
// Set up BaseConverter for B --> q
base_B_to_q_conv_ = allocate<BaseConverter>(pool_, *base_B_, *base_q_, pool_);
// Set up BaseConverter for B --> {m_sk}
base_B_to_m_sk_conv_ = allocate<BaseConverter>(pool_, *base_B_, RNSBase({ m_sk_ }, pool_), pool_);
// base_t_gamma_ is assigned in this function only when t_ is non-zero
if (base_t_gamma_)
{
// Set up BaseConverter for q --> {t, gamma}
base_q_to_t_gamma_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_t_gamma_, pool_);
}
// Compute prod(B) mod q
prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
});
// Scratch value shared by the computations below; the lambdas capture it by reference
uint64_t temp;
// Compute prod(q)^(-1) mod Bsk
inv_prod_q_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
for (size_t i = 0; i < base_Bsk_size; i++)
{
// Reduce prod(q) modulo the i-th element of Bsk, then invert the residue in place
temp = modulo_uint(base_q_->base_prod(), base_q_size, (*base_Bsk_)[i]);
if (!try_invert_uint_mod(temp, (*base_Bsk_)[i], temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_q_mod_Bsk_[i].set(temp, (*base_Bsk_)[i]);
}
// Compute prod(B)^(-1) mod m_sk
temp = modulo_uint(base_B_->base_prod(), base_B_size, m_sk_);
if (!try_invert_uint_mod(temp, m_sk_, temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_B_mod_m_sk_.set(temp, m_sk_);
// Compute m_tilde^(-1) mod Bsk
inv_m_tilde_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
SEAL_ITERATE(iter(inv_m_tilde_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
// m_tilde is first reduced modulo the current Bsk element and then inverted
if (!try_invert_uint_mod(barrett_reduce_64(m_tilde_.value(), get<1>(I)), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
// Compute -prod(q)^(-1) mod m_tilde (the inverse is negated before it is stored)
temp = modulo_uint(base_q_->base_prod(), base_q_size, m_tilde_);
if (!try_invert_uint_mod(temp, m_tilde_, temp))
{
throw logic_error("invalid rns bases");
}
neg_inv_prod_q_mod_m_tilde_.set(negate_uint_mod(temp, m_tilde_), m_tilde_);
// Compute prod(q) mod Bsk
prod_q_mod_Bsk_ = allocate_uint(base_Bsk_size, pool_);
SEAL_ITERATE(iter(prod_q_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
get<0>(I) = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
});
// The remaining {t, gamma} precomputations run only when that base was created above
if (base_t_gamma_)
{
// Compute gamma^(-1) mod t
if (!try_invert_uint_mod(barrett_reduce_64(gamma_.value(), t_), t_, temp))
{
throw logic_error("invalid rns bases");
}
inv_gamma_mod_t_.set(temp, t_);
// Compute prod({t, gamma}) mod q
prod_t_gamma_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size, pool_);
SEAL_ITERATE(iter(prod_t_gamma_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
// Element 0 of the base is t and element 1 is gamma (see its construction above)
get<0>(I).set(
multiply_uint_mod((*base_t_gamma_)[0].value(), (*base_t_gamma_)[1].value(), get<1>(I)),
get<1>(I));
});
// Compute -prod(q)^(-1) mod {t, gamma}
neg_inv_q_mod_t_gamma_ = allocate<MultiplyUIntModOperand>(base_t_gamma_size, pool_);
SEAL_ITERATE(iter(neg_inv_q_mod_t_gamma_, base_t_gamma_->base()), base_t_gamma_size, [&](auto I) {
// The operand field serves as scratch: first the residue of prod(q), then its inverse;
// the final set() call stores the negated inverse
get<0>(I).operand = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
if (!try_invert_uint_mod(get<0>(I).operand, get<1>(I), get<0>(I).operand))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(negate_uint_mod(get<0>(I).operand, get<1>(I)), get<1>(I));
});
}
// Compute q[last]^(-1) mod q[i] for i = 0..last-1
// This is used by modulus switching and rescaling
// The table has base_q_size - 1 entries, so it is empty when q holds a single modulus
inv_q_last_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size - 1, pool_);
SEAL_ITERATE(iter(inv_q_last_mod_q_, base_q_->base()), base_q_size - 1, [&](auto I) {
if (!try_invert_uint_mod((*base_q_)[base_q_size - 1].value(), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
}