in native/src/seal/util/uintarithmod.cpp [15:235]
// Attempts to compute the multiplicative inverse of operand modulo modulus
// using the extended Euclidean algorithm on multi-word unsigned integers.
//
// Parameters:
//   operand      - value to invert; uint64_count words are read.
//   modulus      - the modulus; uint64_count words are read.
//   uint64_count - number of uint64_t words in operand, modulus and result.
//   result       - receives uint64_count words holding the inverse; it is
//                  written only on the paths that return true.
//   pool         - memory pool handed to allocate_uint for the single scratch
//                  buffer of 7 * uint64_count words used below.
//
// Returns true after storing the inverse in result. Returns false, without
// writing to result, when operand is zero or when the value left in the
// denominator after the Euclidean loop is not 1 (GCD is not one).
//
// In SEAL_DEBUG builds only: throws invalid_argument for a null operand,
// modulus or result or for a zero uint64_count, and logic_error when
// is_greater_than_or_equal_uint(operand, modulus, uint64_count) is true.
// Other builds perform none of these checks; the algorithm below is written
// under the assumption that operand < modulus.
bool try_invert_uint_mod(
const uint64_t *operand, const uint64_t *modulus, size_t uint64_count, uint64_t *result, MemoryPool &pool)
{
#ifdef SEAL_DEBUG
// Argument validation is compiled in only when SEAL_DEBUG is defined.
if (!operand)
{
throw invalid_argument("operand");
}
if (!modulus)
{
throw invalid_argument("modulus");
}
if (!uint64_count)
{
throw invalid_argument("uint64_count");
}
if (!result)
{
throw invalid_argument("result");
}
// The operand is expected to be already reduced (operand < modulus).
if (is_greater_than_or_equal_uint(operand, modulus, uint64_count))
{
throw logic_error("operand");
}
#endif
// Cannot invert 0.
int bit_count = get_significant_bit_count_uint(operand, uint64_count);
if (bit_count == 0)
{
return false;
}
// If it is 1 (exactly one significant bit), then its inverse is itself.
// This path returns before any scratch memory is allocated.
if (bit_count == 1)
{
set_uint(1, uint64_count, result);
return true;
}
// One scratch allocation, carved below into seven consecutive regions of
// uint64_count words each: numerator, denominator, difference, quotient,
// invert_prior, invert_curr and invert_next.
auto alloc_anchor(allocate_uint(7 * uint64_count, pool));
// Construct a mutable copy of operand and modulus, with numerator being modulus
// and operand being denominator. Notice that numerator > denominator.
uint64_t *numerator = alloc_anchor.get();
set_uint(modulus, uint64_count, numerator);
uint64_t *denominator = numerator + uint64_count;
set_uint(operand, uint64_count, denominator);
// Create space to store difference.
uint64_t *difference = denominator + uint64_count;
// Determine the significant bit count of each; these counters are kept in
// sync with every shift applied to numerator and denominator below.
int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count);
int denominator_bits = get_significant_bit_count_uint(denominator, uint64_count);
// Create space to store quotient.
uint64_t *quotient = difference + uint64_count;
// Create three sign/magnitude values to store coefficients.
// Each coefficient is a magnitude buffer plus a separate bool for its sign.
// Initialize invert_prior to +0 and invert_curr to +1.
uint64_t *invert_prior = quotient + uint64_count;
set_zero_uint(uint64_count, invert_prior);
bool invert_prior_positive = true;
uint64_t *invert_curr = invert_prior + uint64_count;
set_uint(1, uint64_count, invert_curr);
bool invert_curr_positive = true;
// invert_next is scratch for the coefficient computed in each iteration; its
// contents are written by multiply_truncate_uint before they are read.
uint64_t *invert_next = invert_curr + uint64_count;
bool invert_next_positive = true;
// Perform extended Euclidean algorithm.
// Each iteration divides numerator by denominator (quotient and remainder),
// updates the coefficient sequence with the quotient, then makes the old
// denominator the new numerator and the remainder the new denominator.
while (true)
{
// NOTE: Numerator is > denominator.
// Only perform computation up to last non-zero uint64s.
size_t division_uint64_count = static_cast<size_t>(divide_round_up(numerator_bits, bits_per_uint64));
// Shift denominator to bring MSB in alignment with MSB of numerator.
int denominator_shift = numerator_bits - denominator_bits;
left_shift_uint(denominator, denominator_shift, division_uint64_count, denominator);
denominator_bits += denominator_shift;
// Clear quotient (all uint64_count words, not just the words used by the division).
set_zero_uint(uint64_count, quotient);
// Perform bit-wise division algorithm.
// remaining_shifts counts how many quotient bit positions are still to be
// produced; it starts at the amount the denominator was shifted.
int remaining_shifts = denominator_shift;
while (numerator_bits == denominator_bits)
{
// NOTE: MSBs of numerator and denominator are aligned.
// Even though MSB of numerator and denominator are aligned,
// still possible numerator < denominator.
// A non-zero return from sub_uint is treated as a borrow, i.e. numerator < denominator.
if (sub_uint(numerator, denominator, division_uint64_count, difference))
{
// numerator < denominator and MSBs are aligned, so current
// quotient bit is zero and next one is definitely one.
if (remaining_shifts == 0)
{
// No shifts remain and numerator < denominator so done.
// numerator is left unmodified here and holds the (still shifted) remainder.
break;
}
// Effectively shift numerator left by 1 by instead adding
// numerator to difference (to prevent overflow in numerator).
// difference becomes (numerator - denominator) + numerator = 2 * numerator - denominator.
add_uint(difference, numerator, division_uint64_count, difference);
// Adjust quotient and remaining shifts as a result of shifting numerator.
left_shift_uint(quotient, 1, division_uint64_count, quotient);
remaining_shifts--;
}
// Difference is the new numerator with denominator subtracted.
// Update quotient to reflect subtraction (sets bit 0 of the first quotient word).
*quotient |= 1;
// Determine amount to shift numerator to bring MSB in alignment
// with denominator.
numerator_bits = get_significant_bit_count_uint(difference, division_uint64_count);
int numerator_shift = denominator_bits - numerator_bits;
if (numerator_shift > remaining_shifts)
{
// Clip the maximum shift to determine only the integer
// (as opposed to fractional) bits.
numerator_shift = remaining_shifts;
}
// Shift and update numerator.
if (numerator_bits > 0)
{
// Reads from difference and writes the shifted value into numerator.
left_shift_uint(difference, numerator_shift, division_uint64_count, numerator);
numerator_bits += numerator_shift;
}
else
{
// Difference is zero so no need to shift, just set to zero.
// numerator_bits stays 0, which ends this loop and the outer loop.
set_zero_uint(division_uint64_count, numerator);
}
// Adjust quotient and remaining shifts as a result of
// shifting numerator.
left_shift_uint(quotient, numerator_shift, division_uint64_count, quotient);
remaining_shifts -= numerator_shift;
}
// Correct for shifting of denominator.
right_shift_uint(denominator, denominator_shift, division_uint64_count, denominator);
denominator_bits -= denominator_shift;
// We are done if remainder (which is stored in numerator) is zero.
// In that case the denominator holds the last non-zero remainder, which is
// checked against 1 after the loop.
if (numerator_bits == 0)
{
break;
}
// Undo the alignment shift on the remainder as well: during the division
// the numerator was shifted left in step with the quotient, so it is
// shifted back by the same denominator_shift here.
right_shift_uint(numerator, denominator_shift, division_uint64_count, numerator);
numerator_bits -= denominator_shift;
// Integrate quotient with invert coefficients.
// Calculate: invert_prior + -quotient * invert_curr
// First store |quotient * invert_curr| in invert_next; the leading minus
// gives the product the opposite sign of invert_curr.
multiply_truncate_uint(quotient, invert_curr, uint64_count, invert_next);
invert_next_positive = !invert_curr_positive;
if (invert_prior_positive == invert_next_positive)
{
// If both sides of add have same sign, then simply add and
// do not need to worry about overflow due to known limits
// on the coefficients proved in the euclidean algorithm.
add_uint(invert_prior, invert_next, uint64_count, invert_next);
}
else
{
// If both sides of add have opposite sign, then subtract
// and check for overflow.
// Computes invert_prior - invert_next on the magnitudes, in place over invert_next.
uint64_t borrow = sub_uint(invert_prior, invert_next, uint64_count, invert_next);
if (borrow == 0)
{
// No borrow means |invert_prior| >= |invert_next|,
// so sign is same as invert_prior.
invert_next_positive = invert_prior_positive;
}
else
{
// Borrow means |invert prior| < |invert_next|,
// so sign is opposite of invert_prior.
// The subtraction wrapped around, so negate to recover the magnitude.
invert_next_positive = !invert_prior_positive;
negate_uint(invert_next, uint64_count, invert_next);
}
}
// Swap prior and curr, and then curr and next.
// Net effect: prior <- old curr, curr <- old next, next <- old prior
// (whose buffer is overwritten on the following iteration). Only the
// pointers and sign flags move; no word data is copied.
swap(invert_prior, invert_curr);
swap(invert_prior_positive, invert_curr_positive);
swap(invert_curr, invert_next);
swap(invert_curr_positive, invert_next_positive);
// Swap numerator and denominator using pointer swings.
// The old denominator becomes the new numerator and the remainder becomes
// the new denominator; the bit counts are swapped to match.
swap(numerator, denominator);
swap(numerator_bits, denominator_bits);
}
// The loop exits only when a remainder of zero was produced, leaving the
// last non-zero remainder in denominator.
if (!is_equal_uint(denominator, uint64_count, 1))
{
// GCD is not one, so unable to find inverse.
return false;
}
// Correct coefficient if negative by modulo.
// A negative coefficient -c with c != 0 is mapped to modulus - c.
if (!invert_curr_positive && !is_zero_uint(invert_curr, uint64_count))
{
sub_uint(modulus, invert_curr, uint64_count, invert_curr);
// The sign flag is not read again after this assignment.
invert_curr_positive = true;
}
// Set result.
set_uint(invert_curr, uint64_count, result);
return true;
}