in native/src/seal/ckks.cpp [75:214]
void CKKSEncoder::encode_internal(
double value, parms_id_type parms_id, double scale, Plaintext &destination, MemoryPoolHandle pool)
{
// Verify parameters.
auto context_data_ptr = context_.get_context_data(parms_id);
if (!context_data_ptr)
{
throw invalid_argument("parms_id is not valid for encryption parameters");
}
if (!pool)
{
throw invalid_argument("pool is uninitialized");
}
auto &context_data = *context_data_ptr;
auto &parms = context_data.parms();
auto &coeff_modulus = parms.coeff_modulus();
size_t coeff_modulus_size = coeff_modulus.size();
size_t coeff_count = parms.poly_modulus_degree();
// Quick sanity check
if (!product_fits_in(coeff_modulus_size, coeff_count))
{
throw logic_error("invalid parameters");
}
// Check that scale is positive and not too large
if (scale <= 0 || (static_cast<int>(log2(scale)) >= context_data.total_coeff_modulus_bit_count()))
{
throw invalid_argument("scale out of bounds");
}
// Compute the scaled value
value *= scale;
int coeff_bit_count = static_cast<int>(log2(fabs(value))) + 2;
if (coeff_bit_count >= context_data.total_coeff_modulus_bit_count())
{
throw invalid_argument("encoded value is too large");
}
double two_pow_64 = pow(2.0, 64);
// Resize destination to appropriate size
// Need to first set parms_id to zero, otherwise resize
// will throw an exception.
destination.parms_id() = parms_id_zero;
destination.resize(coeff_count * coeff_modulus_size);
double coeffd = round(value);
bool is_negative = signbit(coeffd);
coeffd = fabs(coeffd);
// Use faster decomposition methods when possible
if (coeff_bit_count <= 64)
{
uint64_t coeffu = static_cast<uint64_t>(fabs(coeffd));
if (is_negative)
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(
destination.data() + (j * coeff_count), coeff_count,
negate_uint_mod(barrett_reduce_64(coeffu, coeff_modulus[j]), coeff_modulus[j]));
}
}
else
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(
destination.data() + (j * coeff_count), coeff_count,
barrett_reduce_64(coeffu, coeff_modulus[j]));
}
}
}
else if (coeff_bit_count <= 128)
{
uint64_t coeffu[2]{ static_cast<uint64_t>(fmod(coeffd, two_pow_64)),
static_cast<uint64_t>(coeffd / two_pow_64) };
if (is_negative)
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(
destination.data() + (j * coeff_count), coeff_count,
negate_uint_mod(barrett_reduce_128(coeffu, coeff_modulus[j]), coeff_modulus[j]));
}
}
else
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(
destination.data() + (j * coeff_count), coeff_count,
barrett_reduce_128(coeffu, coeff_modulus[j]));
}
}
}
else
{
// Slow case
auto coeffu(allocate_uint(coeff_modulus_size, pool));
// We are at this point guaranteed to fit in the allocated space
set_zero_uint(coeff_modulus_size, coeffu.get());
auto coeffu_ptr = coeffu.get();
while (coeffd >= 1)
{
*coeffu_ptr++ = static_cast<uint64_t>(fmod(coeffd, two_pow_64));
coeffd /= two_pow_64;
}
// Next decompose this coefficient
context_data.rns_tool()->base_q()->decompose(coeffu.get(), pool);
// Finally replace the sign if necessary
if (is_negative)
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(
destination.data() + (j * coeff_count), coeff_count,
negate_uint_mod(coeffu[j], coeff_modulus[j]));
}
}
else
{
for (size_t j = 0; j < coeff_modulus_size; j++)
{
fill_n(destination.data() + (j * coeff_count), coeff_count, coeffu[j]);
}
}
}
destination.parms_id() = parms_id;
destination.scale() = scale;
}