SEALContext::ContextData SEALContext::validate(EncryptionParameters parms)

in native/src/seal/context.cpp [135:420]


    // Validates the encryption parameters `parms` and, when they are usable, builds the
    // ContextData holding the constants and tables pre-computed from them.
    //
    // Validation problems are not thrown. Each failing check stores a specific error_type
    // in context_data.qualifiers_.parameter_error and returns immediately, so on failure
    // the returned ContextData is only partially populated. Because of this, the ORDER of
    // the checks below decides which error is reported when several problems apply.
    // On success, parameter_error is left as error_type::success and the qualifiers
    // (using_fft, using_ntt, using_batching, using_fast_plain_lift,
    // using_descending_modulus_chain, sec_level) are filled in.
    //
    // Only std::invalid_argument is caught around RNSBase and NTT-table creation, and
    // std::exception around RNSTool creation. Exceptions raised anywhere else in this
    // function (e.g. by allocations or the arithmetic helpers) propagate to the caller.
    //
    // parms:   the encryption parameters to validate (taken by value; also passed to the
    //          ContextData constructor).
    // returns: the ContextData for `parms`; callers must inspect
    //          qualifiers_.parameter_error before relying on any other field.
    SEALContext::ContextData SEALContext::validate(EncryptionParameters parms)
    {
        ContextData context_data(parms, pool_);
        // Start optimistic; every failing check below overwrites this and returns early.
        context_data.qualifiers_.parameter_error = error_type::success;

        // Is a scheme set?
        if (parms.scheme() == scheme_type::none)
        {
            context_data.qualifiers_.parameter_error = error_type::invalid_scheme;
            return context_data;
        }

        auto &coeff_modulus = parms.coeff_modulus();
        auto &plain_modulus = parms.plain_modulus();

        // The number of coeff moduli must lie in [SEAL_COEFF_MOD_COUNT_MIN, SEAL_COEFF_MOD_COUNT_MAX]
        // to prevent unexpected behaviors. In particular this guarantees coeff_modulus_size >= 1
        // if SEAL_COEFF_MOD_COUNT_MIN >= 1 (relied on by the `coeff_modulus_size - 1` loop below).
        if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN)
        {
            context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_size;
            return context_data;
        }

        size_t coeff_modulus_size = coeff_modulus.size();
        for (size_t i = 0; i < coeff_modulus_size; i++)
        {
            // Check coefficient moduli bounds: each value must have at most SEAL_USER_MOD_BIT_COUNT_MAX
            // significant bits (first shift is zero) and at least SEAL_USER_MOD_BIT_COUNT_MIN bits
            // (second shift is non-zero).
            if (coeff_modulus[i].value() >> SEAL_USER_MOD_BIT_COUNT_MAX ||
                !(coeff_modulus[i].value() >> (SEAL_USER_MOD_BIT_COUNT_MIN - 1)))
            {
                context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_bit_count;
                return context_data;
            }
        }

        // Compute the product of all coeff moduli as a multi-word integer of coeff_modulus_size
        // 64-bit words, and record its significant bit count.
        context_data.total_coeff_modulus_ = allocate_uint(coeff_modulus_size, pool_);
        auto coeff_modulus_values(allocate_uint(coeff_modulus_size, pool_));
        for (size_t i = 0; i < coeff_modulus_size; i++)
        {
            coeff_modulus_values[i] = coeff_modulus[i].value();
        }
        multiply_many_uint64(
            coeff_modulus_values.get(), coeff_modulus_size, context_data.total_coeff_modulus_.get(), pool_);
        context_data.total_coeff_modulus_bit_count_ =
            get_significant_bit_count_uint(context_data.total_coeff_modulus_.get(), coeff_modulus_size);

        // Check the polynomial modulus degree: it must lie in
        // [SEAL_POLY_MOD_DEGREE_MIN, SEAL_POLY_MOD_DEGREE_MAX] and be a power of two.
        size_t poly_modulus_degree = parms.poly_modulus_degree();
        if (poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX)
        {
            // Parameters are not valid
            context_data.qualifiers_.parameter_error = error_type::invalid_poly_modulus_degree;
            return context_data;
        }
        // get_power_of_two returns a negative value when the argument is not a power of two.
        int coeff_count_power = get_power_of_two(poly_modulus_degree);
        if (coeff_count_power < 0)
        {
            // Parameters are not valid
            context_data.qualifiers_.parameter_error = error_type::invalid_poly_modulus_degree_non_power_of_two;
            return context_data;
        }

        // Quick sanity check: reject sizes for which coeff_modulus_size * poly_modulus_degree
        // does not fit (see product_fits_in).
        if (!product_fits_in(coeff_modulus_size, poly_modulus_degree))
        {
            context_data.qualifiers_.parameter_error = error_type::invalid_parameters_too_large;
            return context_data;
        }

        // Polynomial modulus X^(2^k) + 1 is guaranteed at this point
        context_data.qualifiers_.using_fft = true;

        // Assume parameters satisfy desired security level
        context_data.qualifiers_.sec_level = sec_level_;

        // Check if the parameters are secure according to HomomorphicEncryption.org security standard.
        // If sec_level_ is none, insecure parameters are accepted and only the reported sec_level
        // is downgraded; otherwise they are rejected.
        if (context_data.total_coeff_modulus_bit_count_ > CoeffModulus::MaxBitCount(poly_modulus_degree, sec_level_))
        {
            // Not secure according to HomomorphicEncryption.org security standard
            context_data.qualifiers_.sec_level = sec_level_type::none;
            if (sec_level_ != sec_level_type::none)
            {
                // Parameters are not valid
                context_data.qualifiers_.parameter_error = error_type::invalid_parameters_insecure;
                return context_data;
            }
        }

        // Set up RNSBase for coeff_modulus
        // RNSBase's constructor may fail due to:
        //   (1) coeff_mod not coprime
        //   (2) cannot find inverse of punctured products (because of (1))
        Pointer<RNSBase> coeff_modulus_base;
        try
        {
            coeff_modulus_base = allocate<RNSBase>(pool_, coeff_modulus, pool_);
        }
        catch (const invalid_argument &)
        {
            // Parameters are not valid
            context_data.qualifiers_.parameter_error = error_type::failed_creating_rns_base;
            return context_data;
        }

        // Can we use NTT with coeff_modulus?
        // NTT support for coeff_modulus is mandatory: failure is reported as a parameter error.
        context_data.qualifiers_.using_ntt = true;
        try
        {
            CreateNTTTables(coeff_count_power, coeff_modulus, context_data.small_ntt_tables_, pool_);
        }
        catch (const invalid_argument &)
        {
            context_data.qualifiers_.using_ntt = false;
            // Parameters are not valid
            context_data.qualifiers_.parameter_error = error_type::invalid_coeff_modulus_no_ntt;
            return context_data;
        }

        if (parms.scheme() == scheme_type::bfv)
        {
            // Plain modulus must have at least SEAL_PLAIN_MOD_BIT_COUNT_MIN and at most
            // SEAL_PLAIN_MOD_BIT_COUNT_MAX significant bits
            if (plain_modulus.value() >> SEAL_PLAIN_MOD_BIT_COUNT_MAX ||
                !(plain_modulus.value() >> (SEAL_PLAIN_MOD_BIT_COUNT_MIN - 1)))
            {
                context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_bit_count;
                return context_data;
            }

            // Check that all coeff moduli are relatively prime to plain_modulus
            for (size_t i = 0; i < coeff_modulus_size; i++)
            {
                if (!are_coprime(coeff_modulus[i].value(), plain_modulus.value()))
                {
                    context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_coprimality;
                    return context_data;
                }
            }

            // Check that plain_modulus is smaller than total coeff modulus
            if (!is_less_than_uint(
                    plain_modulus.data(), plain_modulus.uint64_count(), context_data.total_coeff_modulus_.get(),
                    coeff_modulus_size))
            {
                // Parameters are not valid
                context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_too_large;
                return context_data;
            }

            // Can we use batching? (NTT with plain_modulus)
            // Unlike coeff_modulus NTT, this is optional: failure only disables batching.
            context_data.qualifiers_.using_batching = true;
            try
            {
                CreateNTTTables(coeff_count_power, { plain_modulus }, context_data.plain_ntt_tables_, pool_);
            }
            catch (const invalid_argument &)
            {
                context_data.qualifiers_.using_batching = false;
            }

            // Check for plain_lift
            // If all the small coefficient moduli are larger than plain modulus, we can quickly
            // lift plain coefficients to RNS form
            context_data.qualifiers_.using_fast_plain_lift = true;
            for (size_t i = 0; i < coeff_modulus_size; i++)
            {
                context_data.qualifiers_.using_fast_plain_lift &= (coeff_modulus[i].value() > plain_modulus.value());
            }

            // Calculate coeff_div_plain_modulus (BFV-"Delta") and the remainder upper_half_increment:
            // quotient = floor(total_coeff_modulus / plain_modulus), remainder = total_coeff_modulus mod
            // plain_modulus. plain_modulus is first widened to coeff_modulus_size words for divide_uint.
            auto temp_coeff_div_plain_modulus = allocate_uint(coeff_modulus_size, pool_);
            context_data.coeff_div_plain_modulus_ = allocate<MultiplyUIntModOperand>(coeff_modulus_size, pool_);
            context_data.upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
            auto wide_plain_modulus(duplicate_uint_if_needed(
                plain_modulus.data(), plain_modulus.uint64_count(), coeff_modulus_size, false, pool_));
            divide_uint(
                context_data.total_coeff_modulus_.get(), wide_plain_modulus.get(), coeff_modulus_size,
                temp_coeff_div_plain_modulus.get(), context_data.upper_half_increment_.get(), pool_);

            // Store the non-RNS form of upper_half_increment (total_coeff_modulus mod plain_modulus)
            // for BFV encryption. The remainder is smaller than plain_modulus (a single word), so
            // word [0] holds it entirely; this must be read before the in-place decompose below.
            context_data.coeff_modulus_mod_plain_modulus_ = context_data.upper_half_increment_[0];

            // Decompose coeff_div_plain_modulus into RNS factors
            coeff_modulus_base->decompose(temp_coeff_div_plain_modulus.get(), pool_);

            for (size_t i = 0; i < coeff_modulus_size; i++)
            {
                context_data.coeff_div_plain_modulus_[i].set(
                    temp_coeff_div_plain_modulus[i], coeff_modulus_base->base()[i]);
            }

            // Decompose upper_half_increment into RNS factors (in place)
            coeff_modulus_base->decompose(context_data.upper_half_increment_.get(), pool_);

            // Calculate (plain_modulus + 1) / 2.
            context_data.plain_upper_half_threshold_ = (plain_modulus.value() + 1) >> 1;

            // Calculate coeff_modulus - plain_modulus.
            context_data.plain_upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
            if (context_data.qualifiers_.using_fast_plain_lift)
            {
                // Calculate coeff_modulus[i] - plain_modulus if using_fast_plain_lift
                // (no underflow: every coeff_modulus[i] > plain_modulus in this case)
                for (size_t i = 0; i < coeff_modulus_size; i++)
                {
                    context_data.plain_upper_half_increment_[i] = coeff_modulus[i].value() - plain_modulus.value();
                }
            }
            else
            {
                // Otherwise store the multi-word difference total_coeff_modulus - plain_modulus
                sub_uint(
                    context_data.total_coeff_modulus(), wide_plain_modulus.get(), coeff_modulus_size,
                    context_data.plain_upper_half_increment_.get());
            }
        }
        else if (parms.scheme() == scheme_type::ckks)
        {
            // Check that plain_modulus is set to zero
            if (!plain_modulus.is_zero())
            {
                // Parameters are not valid
                context_data.qualifiers_.parameter_error = error_type::invalid_plain_modulus_nonzero;
                return context_data;
            }

            // For CKKS, batching is always reported as enabled (no plain_modulus NTT is required)
            context_data.qualifiers_.using_batching = true;

            // Cannot use fast_plain_lift for CKKS since the plaintext coefficients
            // can easily be larger than coefficient moduli
            context_data.qualifiers_.using_fast_plain_lift = false;

            // Calculate 2^64 / 2 = 2^63; read as a signed 64-bit value this is the most negative
            // plaintext coefficient value
            context_data.plain_upper_half_threshold_ = uint64_t(1) << 63;

            // Calculate plain_upper_half_increment = (-2^64) mod coeff_modulus[i] for CKKS plaintexts:
            // tmp = 2^63 mod q, and tmp * (q - 2) == 2^63 * (-2) == -2^64 (mod q).
            // Adding this to the residue of a 64-bit word v yields the residue of the signed value
            // v - 2^64 (presumably applied to coefficients at or above plain_upper_half_threshold_).
            context_data.plain_upper_half_increment_ = allocate_uint(coeff_modulus_size, pool_);
            for (size_t i = 0; i < coeff_modulus_size; i++)
            {
                uint64_t tmp = barrett_reduce_64(uint64_t(1) << 63, coeff_modulus[i]);
                context_data.plain_upper_half_increment_[i] =
                    multiply_uint_mod(tmp, sub_safe(coeff_modulus[i].value(), uint64_t(2)), coeff_modulus[i]);
            }

            // Compute the upper_half_threshold for this modulus: (total_coeff_modulus + 1) / 2,
            // as a multi-word integer of coeff_modulus_size words.
            context_data.upper_half_threshold_ = allocate_uint(coeff_modulus_size, pool_);
            increment_uint(
                context_data.total_coeff_modulus(), coeff_modulus_size, context_data.upper_half_threshold_.get());
            right_shift_uint(
                context_data.upper_half_threshold_.get(), 1, coeff_modulus_size,
                context_data.upper_half_threshold_.get());
        }
        else
        {
            // Any other scheme is unsupported by this validation routine
            context_data.qualifiers_.parameter_error = error_type::invalid_scheme;
            return context_data;
        }

        // Create RNSTool
        // RNSTool's constructor may fail due to:
        //   (1) auxiliary base being too large
        //   (2) cannot find inverse of punctured products in auxiliary base
        try
        {
            context_data.rns_tool_ =
                allocate<RNSTool>(pool_, poly_modulus_degree, *coeff_modulus_base, plain_modulus, pool_);
        }
        catch (const exception &)
        {
            // Parameters are not valid
            context_data.qualifiers_.parameter_error = error_type::failed_creating_rns_tool;
            return context_data;
        }

        // Check whether the coefficient moduli are in strictly decreasing order (only the order is
        // checked here, not primality). coeff_modulus_size >= 1 was ensured by the size check above.
        context_data.qualifiers_.using_descending_modulus_chain = true;
        for (size_t i = 0; i < coeff_modulus_size - 1; i++)
        {
            context_data.qualifiers_.using_descending_modulus_chain &=
                (coeff_modulus[i].value() > coeff_modulus[i + 1].value());
        }

        // Create GaloisTool
        context_data.galois_tool_ = allocate<GaloisTool>(pool_, coeff_count_power, pool_);

        // Done with validation and pre-computations
        return context_data;
    }