modules/platforms/cpp/ignite/common/detail/mpi.cpp (216 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mpi.h"

#include "ignite_error.h"

#include <mbedtls/bignum.h>

#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// The wrapper exposes the raw limb array through mpi::word, so the limb type must be 32-bit.
static_assert(std::is_same_v<mbedtls_mpi_uint, std::uint32_t>, "MbedTLS word should be std::uint32_t.");

namespace ignite::detail {

namespace {

/**
 * Translates an MbedTLS MPI status code into an exception.
 *
 * @param code Status code returned by an mbedtls_mpi_* call.
 * @throws ignite_error If @p code is anything other than 0. Codes without a dedicated
 *  message are reported as "mbedtls: unspecified error".
 */
void check(int code) {
    switch (code) {
        case MBEDTLS_ERR_MPI_ALLOC_FAILED:
            throw ignite_error("mbedtls: alloc failed");
        case MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
            throw ignite_error("mbedtls: bad input data");
        case MBEDTLS_ERR_MPI_BUFFER_TOO_SMALL:
            throw ignite_error("mbedtls: buffer too small");
        case MBEDTLS_ERR_MPI_DIVISION_BY_ZERO:
            throw ignite_error("mbedtls: division by zero");
        case MBEDTLS_ERR_MPI_FILE_IO_ERROR:
            throw ignite_error("mbedtls: file io error");
        case MBEDTLS_ERR_MPI_INVALID_CHARACTER:
            throw ignite_error("mbedtls: invalid characters");
        case MBEDTLS_ERR_MPI_NEGATIVE_VALUE:
            throw ignite_error("mbedtls: negative value");
        case MBEDTLS_ERR_MPI_NOT_ACCEPTABLE:
            throw ignite_error("mbedtls: not acceptable");
        case 0:
            return;
        default:
            throw ignite_error("mbedtls: unspecified error");
    }
}

} // namespace

/** Constructs a freshly initialized MbedTLS value. */
mpi::mpi() {
    init();
}

/**
 * Constructs a value from a 32-bit signed integer.
 *
 * @param v Integer to store.
 * @throws ignite_error If MbedTLS reports a failure.
 */
mpi::mpi(std::int32_t v) {
    init();

    check(mbedtls_mpi_lset(val, v));
}

/**
 * Constructs a value by parsing a base-10 string.
 *
 * @param string Null-terminated decimal representation.
 * @throws ignite_error If MbedTLS rejects the string.
 */
mpi::mpi(const char *string) {
    init();

    assign_from_string(string);
}

/** Releases the MbedTLS value and the heap object that holds it. */
mpi::~mpi() {
    free();
}

/**
 * Copy constructor.
 *
 * @param other Value to copy.
 * @throws ignite_error If MbedTLS fails to copy.
 */
mpi::mpi(const mpi &other) {
    init();

    check(mbedtls_mpi_copy(val, other.val));
}

/**
 * Move constructor. Initializes this object as a fresh value and then swaps the sign,
 * limb count and limb pointer with @p other, so @p other is left holding a freshly
 * initialized value.
 *
 * NOTE(review): init() allocates with new; an allocation failure inside this noexcept
 * function terminates the program.
 *
 * @param other Value to take the contents from.
 */
mpi::mpi(mpi &&other) noexcept {
    init();

    std::swap(val->s, other.val->s);
    std::swap(val->n, other.val->n);
    std::swap(val->p, other.val->p);
}

/**
 * Copy assignment.
 *
 * @param other Value to copy.
 * @return Reference to this object.
 * @throws ignite_error If MbedTLS fails to copy.
 */
mpi &mpi::operator=(const mpi &other) {
    if (this == &other) {
        return *this;
    }

    reinit();

    check(mbedtls_mpi_copy(val, other.val));

    return *this;
}

/**
 * Move assignment. Exchanges the underlying MbedTLS objects, so @p other ends up owning
 * the previous contents of this object.
 *
 * @param other Value to take the contents from.
 * @return Reference to this object.
 */
mpi &mpi::operator=(mpi &&other) noexcept {
    if (this == &other) {
        return *this;
    }

    std::swap(val, other.val);

    return *this;
}

/** Allocates the underlying mbedtls_mpi and initializes it. */
void mpi::init() {
    val = new mbedtls_mpi;
    mbedtls_mpi_init(val);
}

/** Frees the MbedTLS value and deletes the object that holds it. */
void mpi::free() {
    mbedtls_mpi_free(val);
    delete val;
}

/** Discards the current value and replaces it with a freshly initialized one. */
void mpi::reinit() {
    free();
    init();
}

/** @return The MbedTLS sign field converted to mpi_sign. */
mpi_sign mpi::sign() const noexcept {
    return static_cast<mpi_sign>(val->s);
}

/** @return Pointer to the limb array of the underlying value. */
mpi::word *mpi::pointer() const noexcept {
    return val->p;
}

/**
 * @return The limb count field of the underlying value.
 *
 * NOTE(review): the field is converted to unsigned short; confirm that it cannot exceed
 * that range for the MbedTLS version in use.
 */
unsigned short mpi::length() const noexcept {
    return val->n;
}

/** @return A mag_view built from the limb pointer, the limb count and mbedtls_mpi_size(). */
mpi::mag_view mpi::magnitude() const noexcept {
    return {val->p, val->n, mbedtls_mpi_size(val)};
}

/** @return true if the value compares equal to integer 0. */
bool mpi::is_zero() const noexcept {
    return mbedtls_mpi_cmp_int(val, 0) == 0;
}

/** @return true if the sign field is positive and the value is not zero. */
bool mpi::is_positive() const noexcept {
    return val->s > 0 && !is_zero();
}

/**
 * @return true if the sign field is negative.
 *
 * Only the sign field is inspected; unlike is_positive(), zero is not special-cased here.
 */
bool mpi::is_negative() const noexcept {
    return val->s < 0;
}

/**
 * Overwrites the sign field.
 *
 * @param sign New sign.
 */
void mpi::set_sign(mpi_sign sign) {
    val->s = sign;
}

/** Sets the sign field to mpi_sign::POSITIVE. */
void mpi::make_positive() noexcept {
    val->s = mpi_sign::POSITIVE;
}

/** Sets the sign field to mpi_sign::NEGATIVE. */
void mpi::make_negative() noexcept {
    val->s = mpi_sign::NEGATIVE;
}

/** Flips the sign field; a zero value is left untouched. */
void mpi::negate() noexcept {
    if (!is_zero()) {
        val->s = -val->s;
    }
}

/**
 * Exchanges the contents of two values by swapping sign, limb count and limb pointer.
 *
 * @param lhs First value.
 * @param rhs Second value.
 */
void swap(mpi &lhs, mpi &rhs) {
    using std::swap;

    std::swap(lhs.val->s, rhs.val->s);
    std::swap(lhs.val->n, rhs.val->n);
    std::swap(lhs.val->p, rhs.val->p);
}

/**
 * @param addendum Value to add.
 * @return this + addendum.
 * @throws ignite_error If MbedTLS reports a failure.
 */
mpi mpi::operator+(const mpi &addendum) const {
    mpi result;

    check(mbedtls_mpi_add_mpi(result.val, val, addendum.val));

    return result;
}

/**
 * @param subtrahend Value to subtract.
 * @return this - subtrahend.
 * @throws ignite_error If MbedTLS reports a failure.
 */
mpi mpi::operator-(const mpi &subtrahend) const {
    mpi result;

    check(mbedtls_mpi_sub_mpi(result.val, val, subtrahend.val));

    return result;
}

/**
 * @param factor Value to multiply by.
 * @return this * factor.
 * @throws ignite_error If MbedTLS reports a failure.
 */
mpi mpi::operator*(const mpi &factor) const {
    mpi result;

    check(mbedtls_mpi_mul_mpi(result.val, val, factor.val));

    return result;
}

/**
 * @param divisor Value to divide by.
 * @return Quotient of this / divisor.
 * @throws ignite_error If MbedTLS reports a failure, e.g. division by zero.
 */
mpi mpi::operator/(const mpi &divisor) const {
    mpi result;

    // The remainder output is not requested.
    check(mbedtls_mpi_div_mpi(result.val, nullptr, val, divisor.val));

    return result;
}

/**
 * @param divisor Value to divide by.
 * @return Remainder of this / divisor.
 * @throws ignite_error If MbedTLS reports a failure, e.g. division by zero.
 */
mpi mpi::operator%(const mpi &divisor) const {
    mpi remainder;

    // The quotient output is not requested.
    check(mbedtls_mpi_div_mpi(nullptr, remainder.val, val, divisor.val));

    return remainder;
}

/**
 * In-place addition.
 *
 * @param addendum Value to add.
 * @throws ignite_error If MbedTLS reports a failure.
 */
void mpi::add(const mpi &addendum) {
    check(mbedtls_mpi_add_mpi(val, val, addendum.val));
}

/**
 * In-place subtraction.
 *
 * @param subtrahend Value to subtract.
 * @throws ignite_error If MbedTLS reports a failure.
 */
void mpi::subtract(const mpi &subtrahend) {
    check(mbedtls_mpi_sub_mpi(val, val, subtrahend.val));
}

/**
 * In-place multiplication.
 *
 * @param factor Value to multiply by.
 * @throws ignite_error If MbedTLS reports a failure.
 */
void mpi::multiply(const mpi &factor) {
    check(mbedtls_mpi_mul_mpi(val, val, factor.val));
}

/**
 * In-place division; this object receives the quotient.
 *
 * @param divisor Value to divide by.
 * @throws ignite_error If MbedTLS reports a failure, e.g. division by zero.
 */
void mpi::divide(const mpi &divisor) {
    check(mbedtls_mpi_div_mpi(val, nullptr, val, divisor.val));
}

/**
 * In-place modulo; this object receives the remainder.
 *
 * @param divisor Value to divide by.
 * @throws ignite_error If MbedTLS reports a failure, e.g. division by zero.
 */
void mpi::modulo(const mpi &divisor) {
    check(mbedtls_mpi_div_mpi(nullptr, val, val, divisor.val));
}

/**
 * Forwards to mbedtls_mpi_shrink.
 *
 * @param limbs Limb count passed to MbedTLS.
 * @throws ignite_error If MbedTLS reports a failure.
 */
void mpi::shrink(size_t limbs) {
    check(mbedtls_mpi_shrink(val, limbs));
}

/**
 * Forwards to mbedtls_mpi_grow.
 *
 * @param limbs Limb count passed to MbedTLS.
 * @throws ignite_error If MbedTLS reports a failure.
 */
void mpi::grow(size_t limbs) {
    check(mbedtls_mpi_grow(val, limbs));
}

/**
 * Computes quotient and remainder in a single call.
 *
 * @param divisor Value to divide by.
 * @param remainder Receives the remainder.
 * @return Quotient of this / divisor.
 * @throws ignite_error If MbedTLS reports a failure, e.g. division by zero.
 */
mpi mpi::div_and_mod(const mpi &divisor, mpi &remainder) const {
    mpi result;

    check(mbedtls_mpi_div_mpi(result.val, remainder.val, val, divisor.val));

    return result;
}

/**
 * Replaces the current value with one parsed from a base-10 string.
 *
 * @param string Null-terminated decimal representation.
 * @throws ignite_error If MbedTLS rejects the string.
 */
void mpi::assign_from_string(const char *string) {
    reinit();

    check(mbedtls_mpi_read_string(val, 10, string));
}

/**
 * Formats the value as a base-10 string.
 *
 * @return Decimal representation without a trailing null character.
 * @throws ignite_error If MbedTLS reports a failure.
 */
std::string mpi::to_string() const {
    std::size_t required_size = 0;

    // First call with an empty buffer only to learn the required buffer size.
    auto code = mbedtls_mpi_write_string(val, 10, nullptr, 0, &required_size);

    if (code == MBEDTLS_ERR_MPI_BUFFER_TOO_SMALL) {
        std::string buffer(required_size, 0);

        check(mbedtls_mpi_write_string(val, 10, buffer.data(), required_size, &required_size));

        // -1 for \0. We don't need it for std::string.
        buffer.resize(required_size - 1);

        return buffer;
    }

    check(code);

    return {};
}

/**
 * Equality comparison (sign-aware).
 *
 * @param other Value to compare with.
 * @return true if both values are equal.
 */
bool mpi::operator==(const mpi &other) const {
    // compare() yields 0 for equal operands and a non-zero value otherwise, so the result
    // must be tested against zero. Converting it to bool directly would report equality
    // exactly when the values differ.
    return compare(other) == 0;
}

/**
 * Three-way comparison.
 *
 * @param other Value to compare with.
 * @param ignore_sign If true, magnitudes are compared via mbedtls_mpi_cmp_abs; otherwise
 *  the signed values are compared via mbedtls_mpi_cmp_mpi.
 * @return The MbedTLS comparison result: 0 if equal, non-zero otherwise.
 */
int mpi::compare(const mpi &other, bool ignore_sign) const noexcept {
    return ignore_sign ? mbedtls_mpi_cmp_abs(val, other.val) : mbedtls_mpi_cmp_mpi(val, other.val);
}

/** @return The result of mbedtls_mpi_bitlen for the underlying value. */
std::size_t mpi::magnitude_bit_length() const noexcept {
    return mbedtls_mpi_bitlen(val);
}

/**
 * Serializes the value into a caller-provided buffer.
 *
 * @param data Destination buffer.
 * @param size Size of the destination buffer in bytes.
 * @param big_endian Selects mbedtls_mpi_write_binary (true) or mbedtls_mpi_write_binary_le (false).
 * @return true if MbedTLS reported success, false otherwise.
 */
bool mpi::write(std::uint8_t *data, std::size_t size, bool big_endian) {
    if (big_endian) {
        return mbedtls_mpi_write_binary(val, data, size) == 0;
    }

    return mbedtls_mpi_write_binary_le(val, data, size) == 0;
}

/**
 * Loads the value from a byte buffer.
 *
 * @param data Source buffer.
 * @param size Size of the source buffer in bytes.
 * @param big_endian Selects mbedtls_mpi_read_binary (true) or mbedtls_mpi_read_binary_le (false).
 * @return true if MbedTLS reported success, false otherwise.
 */
bool mpi::read(const std::uint8_t *data, std::size_t size, bool big_endian) {
    if (big_endian) {
        return mbedtls_mpi_read_binary(val, data, size) == 0;
    }

    return mbedtls_mpi_read_binary_le(val, data, size) == 0;
}

} // namespace ignite::detail