in native/src/seal/context.cpp [135:420]
// Validates the encryption parameters `parms` and precomputes the data stored in the
// returned ContextData (total coefficient modulus, NTT tables, RNSTool, GaloisTool and
// the scheme-specific scaling/threshold constants).
//
// Invalid parameters are not reported by throwing. Instead the function sets
// context_data.qualifiers_.parameter_error to the error_type of the FIRST check that
// fails and returns immediately. The order of the checks therefore decides which error
// is reported when several problems are present:
//    1. a scheme is set                          -> invalid_scheme
//    2. number of coefficient moduli             -> invalid_coeff_modulus_size
//    3. bit count of every coefficient modulus   -> invalid_coeff_modulus_bit_count
//    4. poly_modulus_degree range                -> invalid_poly_modulus_degree
//    5. poly_modulus_degree is a power of two    -> invalid_poly_modulus_degree_non_power_of_two
//    6. product_fits_in sanity check             -> invalid_parameters_too_large
//    7. security level bound                     -> invalid_parameters_insecure
//    8. RNSBase construction                     -> failed_creating_rns_base
//    9. NTT tables for coeff_modulus             -> invalid_coeff_modulus_no_ntt
//   10. scheme-specific plain_modulus checks     -> invalid_plain_modulus_* / invalid_scheme
//   11. RNSTool construction                     -> failed_creating_rns_tool
// Fields assigned before the failing check (e.g. total_coeff_modulus_,
// qualifiers_.using_fft, qualifiers_.sec_level) keep their values in the returned
// object. If every check passes, parameter_error remains error_type::success.
//
// Only the exceptions explicitly caught below are turned into error codes; anything
// else (e.g. from pool allocations or the multi-precision helpers) propagates.
//
// @param parms  Encryption parameters to validate; taken by value and passed to the
//               ContextData constructor.
// @return       The ContextData, fully populated on success, or populated only up to
//               the failing check otherwise.
SEALContext::ContextData SEALContext::validate(EncryptionParameters parms)
{
ContextData context_data(parms, pool_);
// Start optimistic; each failing check below overwrites this and returns early.
context_data.qualifiers_.parameter_error = error_type::success;
// Is a scheme set?
if (parms.scheme() == scheme_type::none)
{
context_data.qualifiers_.parameter_error = error_type::invalid_scheme;
return context_data;
}
auto &coeff_modulus = parms.coeff_modulus();
auto &plain_modulus = parms.plain_modulus();
// The number of coefficient moduli must lie in
// [SEAL_COEFF_MOD_COUNT_MIN, SEAL_COEFF_MOD_COUNT_MAX] to prevent unexpected behaviors.
if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN)
{
context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_size;
return context_data;
}
size_t coeff_modulus_size = coeff_modulus.size();
for (size_t i = 0; i < coeff_modulus_size; i++)
{
// Check coefficient moduli bounds. A nonzero `value >> MAX` means the modulus has
// more than SEAL_USER_MOD_BIT_COUNT_MAX bits; a zero `value >> (MIN - 1)` means it
// has fewer than SEAL_USER_MOD_BIT_COUNT_MIN bits.
if (coeff_modulus[i].value() >> SEAL_USER_MOD_BIT_COUNT_MAX ||
!(coeff_modulus[i].value() >> (SEAL_USER_MOD_BIT_COUNT_MIN - 1)))
{
context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_bit_count;
return context_data;
}
}
// Compute the product of all coeff moduli as a multi-precision integer of
// coeff_modulus_size 64-bit words: each factor occupies one word, so the product of
// coeff_modulus_size factors fits in coeff_modulus_size words.
context_data.total_coeff_modulus_ = allocate_uint(coeff_modulus_size, pool_);
auto coeff_modulus_values(allocate_uint(coeff_modulus_size, pool_));
for (size_t i = 0; i < coeff_modulus_size; i++)
{
coeff_modulus_values[i] = coeff_modulus[i].value();
}
multiply_many_uint64(
coeff_modulus_values.get(), coeff_modulus_size, context_data.total_coeff_modulus_.get(), pool_);
context_data.total_coeff_modulus_bit_count_ =
get_significant_bit_count_uint(context_data.total_coeff_modulus_.get(), coeff_modulus_size);
// Check the polynomial modulus degree (range first, then power of two)
size_t poly_modulus_degree = parms.poly_modulus_degree();
if (poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX)
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_poly_modulus_degree;
return context_data;
}
// get_power_of_two yields a negative value when the degree is not a power of two;
// otherwise coeff_count_power = log2(poly_modulus_degree).
int coeff_count_power = get_power_of_two(poly_modulus_degree);
if (coeff_count_power < 0)
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_poly_modulus_degree_non_power_of_two;
return context_data;
}
// Quick sanity check: reject when product_fits_in reports that
// coeff_modulus_size * poly_modulus_degree does not fit (presumably protecting later
// size computations from overflow -- NOTE(review): confirm).
if (!product_fits_in(coeff_modulus_size, poly_modulus_degree))
{
context_data.qualifiers_.parameter_error = error_type::invalid_parameters_too_large;
return context_data;
}
// Polynomial modulus X^(2^k) + 1 is guaranteed at this point
context_data.qualifiers_.using_fft = true;
// Assume parameters satisfy desired security level
context_data.qualifiers_.sec_level = sec_level_;
// Check if the parameters are secure according to HomomorphicEncryption.org security standard:
// the total coefficient modulus bit count must not exceed CoeffModulus::MaxBitCount for this
// degree and the requested security level.
if (context_data.total_coeff_modulus_bit_count_ > CoeffModulus::MaxBitCount(poly_modulus_degree, sec_level_))
{
// Not secure according to HomomorphicEncryption.org security standard
context_data.qualifiers_.sec_level = sec_level_type::none;
// Reject only when a security level was requested. With sec_level_type::none the
// caller has opted out of enforcement, and validation continues with sec_level = none.
if (sec_level_ != sec_level_type::none)
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_parameters_insecure;
return context_data;
}
}
// Set up RNSBase for coeff_modulus
// RNSBase's constructor may fail due to:
// (1) coeff_mod not coprime
// (2) cannot find inverse of punctured products (because of (1))
Pointer<RNSBase> coeff_modulus_base;
try
{
coeff_modulus_base = allocate<RNSBase>(pool_, coeff_modulus, pool_);
}
catch (const invalid_argument &)
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::failed_creating_rns_base;
return context_data;
}
// Can we use NTT with coeff_modulus? CreateNTTTables reports failure with
// invalid_argument. NTT support for coeff_modulus is mandatory, so failure is an error.
context_data.qualifiers_.using_ntt = true;
try
{
CreateNTTTables(coeff_count_power, coeff_modulus, context_data.small_ntt_tables_, pool_);
}
catch (const invalid_argument &)
{
context_data.qualifiers_.using_ntt = false;
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_no_ntt;
return context_data;
}
if (parms.scheme() == scheme_type::bfv)
{
// Plain modulus bit count must lie in
// [SEAL_PLAIN_MOD_BIT_COUNT_MIN, SEAL_PLAIN_MOD_BIT_COUNT_MAX] (same shift test as
// for the coefficient moduli above).
if (plain_modulus.value() >> SEAL_PLAIN_MOD_BIT_COUNT_MAX ||
!(plain_modulus.value() >> (SEAL_PLAIN_MOD_BIT_COUNT_MIN - 1)))
{
context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_bit_count;
return context_data;
}
// Check that all coeff moduli are relatively prime to plain_modulus
for (size_t i = 0; i < coeff_modulus_size; i++)
{
if (!are_coprime(coeff_modulus[i].value(), plain_modulus.value()))
{
context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_coprimality;
return context_data;
}
}
// Check that plain_modulus is smaller than total coeff modulus
if (!is_less_than_uint(
plain_modulus.data(), plain_modulus.uint64_count(), context_data.total_coeff_modulus_.get(),
coeff_modulus_size))
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_too_large;
return context_data;
}
// Can we use batching? (NTT with plain_modulus)
// Unlike coeff_modulus, failing to build NTT tables for plain_modulus is not an error;
// it only disables batching.
context_data.qualifiers_.using_batching = true;
try
{
CreateNTTTables(coeff_count_power, { plain_modulus }, context_data.plain_ntt_tables_, pool_);
}
catch (const invalid_argument &)
{
context_data.qualifiers_.using_batching = false;
}
// Check for plain_lift
// If all the small coefficient moduli are larger than plain modulus, we can quickly
// lift plain coefficients to RNS form
context_data.qualifiers_.using_fast_plain_lift = true;
for (size_t i = 0; i < coeff_modulus_size; i++)
{
context_data.qualifiers_.using_fast_plain_lift &= (coeff_modulus[i].value() > plain_modulus.value());
}
// Calculate coeff_div_plain_modulus (BFV-"Delta") and the remainder upper_half_increment.
// plain_modulus is first widened to coeff_modulus_size words so both operands have the
// same width; divide_uint then writes floor(total_coeff_modulus / plain_modulus) into
// temp_coeff_div_plain_modulus and the remainder into upper_half_increment_.
auto temp_coeff_div_plain_modulus = allocate_uint(coeff_modulus_size, pool_);
context_data.coeff_div_plain_modulus_ = allocate<MultiplyUIntModOperand>(coeff_modulus_size, pool_);
context_data.upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
auto wide_plain_modulus(duplicate_uint_if_needed(
plain_modulus.data(), plain_modulus.uint64_count(), coeff_modulus_size, false, pool_));
divide_uint(
context_data.total_coeff_modulus_.get(), wide_plain_modulus.get(), coeff_modulus_size,
temp_coeff_div_plain_modulus.get(), context_data.upper_half_increment_.get(), pool_);
// Store the non-RNS form of upper_half_increment for BFV encryption. The remainder is
// smaller than plain_modulus.value(), so word 0 holds all of it. This must happen
// before the in-place RNS decomposition below overwrites the buffer.
context_data.coeff_modulus_mod_plain_modulus_ = context_data.upper_half_increment_[0];
// Decompose coeff_div_plain_modulus into RNS factors (in place): afterwards word i holds
// Delta reduced modulo coeff_modulus_base->base()[i].
coeff_modulus_base->decompose(temp_coeff_div_plain_modulus.get(), pool_);
for (size_t i = 0; i < coeff_modulus_size; i++)
{
context_data.coeff_div_plain_modulus_[i].set(
temp_coeff_div_plain_modulus[i], coeff_modulus_base->base()[i]);
}
// Decompose upper_half_increment into RNS factors (in place)
coeff_modulus_base->decompose(context_data.upper_half_increment_.get(), pool_);
// Calculate (plain_modulus + 1) / 2, i.e. ceil(plain_modulus / 2).
context_data.plain_upper_half_threshold_ = (plain_modulus.value() + 1) >> 1;
// Calculate coeff_modulus - plain_modulus. The representation depends on
// using_fast_plain_lift:
//  - true:  per-modulus values coeff_modulus[i] - plain_modulus (no underflow, because
//           fast plain lift requires every coeff_modulus[i] > plain_modulus);
//  - false: one multi-precision value total_coeff_modulus - plain_modulus over
//           coeff_modulus_size words (non-negative by the "too large" check above).
context_data.plain_upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
if (context_data.qualifiers_.using_fast_plain_lift)
{
// Calculate coeff_modulus[i] - plain_modulus if using_fast_plain_lift
for (size_t i = 0; i < coeff_modulus_size; i++)
{
context_data.plain_upper_half_increment_[i] = coeff_modulus[i].value() - plain_modulus.value();
}
}
else
{
sub_uint(
context_data.total_coeff_modulus(), wide_plain_modulus.get(), coeff_modulus_size,
context_data.plain_upper_half_increment_.get());
}
}
else if (parms.scheme() == scheme_type::ckks)
{
// Check that plain_modulus is set to zero; CKKS requires it to be unset.
if (!plain_modulus.is_zero())
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_nonzero;
return context_data;
}
// When using CKKS, batching is always enabled
context_data.qualifiers_.using_batching = true;
// Cannot use fast_plain_lift for CKKS since the plaintext coefficients
// can easily be larger than coefficient moduli
context_data.qualifiers_.using_fast_plain_lift = false;
// Calculate 2^64 / 2 = 2^63, the most negative value of a 64-bit two's-complement word.
// Plaintext words at or above this threshold presumably represent negative coefficients.
context_data.plain_upper_half_threshold_ = uint64_t(1) << 63;
// Calculate plain_upper_half_increment[i] = (-2^64) mod coeff_modulus[i] for CKKS
// plaintexts: the value that, added to (w mod q), yields (w - 2^64) mod q for a 64-bit
// word w. Because 2^64 does not fit in uint64_t it is computed as
// (2^63 mod q) * (q - 2) mod q, using q - 2 == -2 (mod q).
context_data.plain_upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
for (size_t i = 0; i < coeff_modulus_size; i++)
{
uint64_t tmp = barrett_reduce_64(uint64_t(1) << 63, coeff_modulus[i]);
context_data.plain_upper_half_increment_[i] =
multiply_uint_mod(tmp, sub_safe(coeff_modulus[i].value(), uint64_t(2)), coeff_modulus[i]);
}
// Compute the upper_half_threshold for this modulus:
// (total_coeff_modulus + 1) >> 1 = ceil(total_coeff_modulus / 2), as a
// coeff_modulus_size-word integer.
context_data.upper_half_threshold_ = allocate_uint(coeff_modulus_size, pool_);
increment_uint(
context_data.total_coeff_modulus(), coeff_modulus_size, context_data.upper_half_threshold_.get());
right_shift_uint(
context_data.upper_half_threshold_.get(), 1, coeff_modulus_size,
context_data.upper_half_threshold_.get());
}
else
{
// Any other scheme is not supported by this validation routine.
context_data.qualifiers_.parameter_error = error_type::invalid_scheme;
return context_data;
}
// Create RNSTool
// RNSTool's constructor may fail due to:
// (1) auxiliary base being too large
// (2) cannot find inverse of punctured products in auxiliary base
// Any std::exception from the constructor is mapped to failed_creating_rns_tool.
try
{
context_data.rns_tool_ =
allocate<RNSTool>(pool_, poly_modulus_degree, *coeff_modulus_base, plain_modulus, pool_);
}
catch (const exception &)
{
// Parameters are not valid
context_data.qualifiers_.parameter_error = error_type::failed_creating_rns_tool;
return context_data;
}
// Record whether the coefficient moduli are in strictly decreasing order
// (primality is not checked here).
// NOTE(review): `coeff_modulus_size - 1` relies on SEAL_COEFF_MOD_COUNT_MIN >= 1, which the
// size check above enforces only if the macro is at least 1; an empty coeff_modulus would
// make the unsigned subtraction wrap around.
context_data.qualifiers_.using_descending_modulus_chain = true;
for (size_t i = 0; i < coeff_modulus_size - 1; i++)
{
context_data.qualifiers_.using_descending_modulus_chain &=
(coeff_modulus[i].value() > coeff_modulus[i + 1].value());
}
// Create GaloisTool for degree 2^coeff_count_power
context_data.galois_tool_ = allocate<GaloisTool>(pool_, coeff_count_power, pool_);
// Done with validation and pre-computations
return context_data;
}