in native/src/seal/ckks.h [448:629]
void encode_internal(
    const T *values, std::size_t values_size, parms_id_type parms_id, double scale, Plaintext &destination,
    MemoryPoolHandle pool)
{
    /*
    Encodes values[0..values_size) into destination under the encryption parameters identified by
    parms_id, scaled by scale, and leaves the result in NTT form.

    Pipeline:
      1. Each value and its complex conjugate are scattered into a buffer of 2 * slots_ entries through
         matrix_reps_index_map_. The buffer is allocated with initial value 0, so unused slots stay zero.
      2. The buffer is transformed by fft_handler_.transform_from_rev using inv_root_powers_ and a
         scale / n factor.
      3. The real part of every entry is rounded to an integer and reduced modulo each prime in
         coeff_modulus. Negative coefficients are negated modulo the prime.
      4. Each RNS component is transformed with ntt_negacyclic_harvey.

    @param[in] values Values to encode; may be null only if values_size is 0
    @param[in] values_size Number of values; must be at most slots_
    @param[in] parms_id Identifies the context data whose parameters are used
    @param[in] scale Scaling factor; must be finite and positive, and static_cast<int>(log2(scale)) + 1
    must be less than the total coeff_modulus bit count
    @param[out] destination Resized to coeff_count * coeff_modulus_size residues; its parms_id and scale
    are set to parms_id and scale. All validation checks run before destination is touched.
    @param[in] pool Memory pool used for temporary allocations
    @throws std::invalid_argument if parms_id is not valid, values is null while values_size > 0,
    values_size > slots_, pool is uninitialized, scale is out of bounds, or the encoded coefficients are
    not finite or too large for coeff_modulus
    @throws std::logic_error if util::product_fits_in(coeff_modulus_size, coeff_count) fails
    */

    // Verify parameters.
    auto context_data_ptr = context_.get_context_data(parms_id);
    if (!context_data_ptr)
    {
        throw std::invalid_argument("parms_id is not valid for encryption parameters");
    }
    if (!values && values_size > 0)
    {
        throw std::invalid_argument("values cannot be null");
    }
    if (values_size > slots_)
    {
        throw std::invalid_argument("values_size is too large");
    }
    if (!pool)
    {
        throw std::invalid_argument("pool is uninitialized");
    }

    auto &context_data = *context_data_ptr;
    auto &parms = context_data.parms();
    auto &coeff_modulus = parms.coeff_modulus();
    std::size_t coeff_modulus_size = coeff_modulus.size();
    std::size_t coeff_count = parms.poly_modulus_degree();

    // Quick sanity check
    if (!util::product_fits_in(coeff_modulus_size, coeff_count))
    {
        throw std::logic_error("invalid parameters");
    }

    // Check that scale is finite, positive, and not too large. The finiteness test must come first:
    // for NaN or infinity, converting log2(scale) to int is undefined behavior.
    if (!std::isfinite(scale) || scale <= 0 ||
        (static_cast<int>(std::log2(scale)) + 1 >= context_data.total_coeff_modulus_bit_count()))
    {
        throw std::invalid_argument("scale out of bounds");
    }

    auto ntt_tables = context_data.small_ntt_tables();

    // values_size is guaranteed to be no bigger than slots_
    std::size_t n = util::mul_safe(slots_, std::size_t(2));

    auto conj_values = util::allocate<std::complex<double>>(n, pool, 0);
    for (std::size_t i = 0; i < values_size; i++)
    {
        conj_values[matrix_reps_index_map_[i]] = values[i];
        // TODO: if values are real, the following values should be set to zero, and multiply results by 2.
        conj_values[matrix_reps_index_map_[i + slots_]] = std::conj(values[i]);
    }
    double fix = scale / static_cast<double>(n);
    fft_handler_.transform_from_rev(conj_values.get(), util::get_power_of_two(n), inv_root_powers_.get(), &fix);

    double max_coeff = 0;
    for (std::size_t i = 0; i < n; i++)
    {
        double abs_coeff = std::fabs(conj_values[i].real());
        // Reject NaN/infinity explicitly. std::max(max_coeff, NaN) would silently discard a NaN, and an
        // infinity would break the bit-count computation below. Either one would later reach an undefined
        // floating-point-to-integer conversion. Infinities also arise when finite inputs overflow after
        // scaling.
        if (!std::isfinite(abs_coeff))
        {
            throw std::invalid_argument("encoded values are not finite");
        }
        max_coeff = std::max<>(max_coeff, abs_coeff);
    }

    // Verify that the values are not too large to fit in coeff_modulus.
    // Note that we have an extra + 1 for the sign bit.
    // Don't compute logarithms of numbers less than 1.
    int max_coeff_bit_count = static_cast<int>(std::ceil(std::log2(std::max<>(max_coeff, 1.0)))) + 1;
    if (max_coeff_bit_count >= context_data.total_coeff_modulus_bit_count())
    {
        throw std::invalid_argument("encoded values are too large");
    }

    double two_pow_64 = std::pow(2.0, 64);

    // Resize destination to appropriate size.
    // Need to first set parms_id to zero, otherwise resize will throw an exception.
    destination.parms_id() = parms_id_zero;
    destination.resize(util::mul_safe(coeff_count, coeff_modulus_size));

    // Use faster decomposition methods when possible
    if (max_coeff_bit_count <= 64)
    {
        // |coeff| <= 2^63 here, so the magnitude fits in a single std::uint64_t
        for (std::size_t i = 0; i < n; i++)
        {
            double coeffd = std::round(conj_values[i].real());
            bool is_negative = std::signbit(coeffd);

            std::uint64_t coeffu = static_cast<std::uint64_t>(std::fabs(coeffd));

            if (is_negative)
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] =
                        util::negate_uint_mod(util::barrett_reduce_64(coeffu, coeff_modulus[j]), coeff_modulus[j]);
                }
            }
            else
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] = util::barrett_reduce_64(coeffu, coeff_modulus[j]);
                }
            }
        }
    }
    else if (max_coeff_bit_count <= 128)
    {
        // |coeff| <= 2^127 here; split the magnitude into low and high 64-bit words
        for (std::size_t i = 0; i < n; i++)
        {
            double coeffd = std::round(conj_values[i].real());
            bool is_negative = std::signbit(coeffd);
            coeffd = std::fabs(coeffd);

            std::uint64_t coeffu[2]{ static_cast<std::uint64_t>(std::fmod(coeffd, two_pow_64)),
                                     static_cast<std::uint64_t>(coeffd / two_pow_64) };

            if (is_negative)
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] =
                        util::negate_uint_mod(util::barrett_reduce_128(coeffu, coeff_modulus[j]), coeff_modulus[j]);
                }
            }
            else
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] = util::barrett_reduce_128(coeffu, coeff_modulus[j]);
                }
            }
        }
    }
    else
    {
        // Slow case: build a multi-word integer of coeff_modulus_size 64-bit words, then decompose it
        auto coeffu(util::allocate_uint(coeff_modulus_size, pool));
        for (std::size_t i = 0; i < n; i++)
        {
            double coeffd = std::round(conj_values[i].real());
            bool is_negative = std::signbit(coeffd);
            coeffd = std::fabs(coeffd);

            // We are at this point guaranteed to fit in the allocated space. The loop terminates because
            // coeffd is finite (checked above) and shrinks by a factor of 2^64 each iteration.
            util::set_zero_uint(coeff_modulus_size, coeffu.get());
            auto coeffu_ptr = coeffu.get();
            while (coeffd >= 1)
            {
                *coeffu_ptr++ = static_cast<std::uint64_t>(std::fmod(coeffd, two_pow_64));
                coeffd /= two_pow_64;
            }

            // Next decompose this coefficient. NOTE(review): decompose appears to overwrite coeffu in place
            // with its residues modulo each base_q prime (coeffu[j] is read as a residue below) — confirm.
            context_data.rns_tool()->base_q()->decompose(coeffu.get(), pool);

            // Finally replace the sign if necessary
            if (is_negative)
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] = util::negate_uint_mod(coeffu[j], coeff_modulus[j]);
                }
            }
            else
            {
                for (std::size_t j = 0; j < coeff_modulus_size; j++)
                {
                    destination[i + (j * coeff_count)] = coeffu[j];
                }
            }
        }
    }

    // Transform to NTT domain
    for (std::size_t i = 0; i < coeff_modulus_size; i++)
    {
        util::ntt_negacyclic_harvey(destination.data(i * coeff_count), ntt_tables[i]);
    }

    destination.parms_id() = parms_id;
    destination.scale() = scale;
}