// sql_utils/public/functions/arithmetics.h (556 lines of code, raw source)

/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_PUBLIC_FUNCTIONS_ARITHMETICS_H_
#define THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_PUBLIC_FUNCTIONS_ARITHMETICS_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "sql_utils/public/functions/arithmetics_internal.h"
#include "sql_utils/public/functions/convert.h"
#include "sql_utils/public/functions/util.h"
#include "sql_utils/public/numeric_value.h"
#include "absl/base/optimization.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "sql_utils/base/status.h"

// Fallback so that the `#if __has_builtin(...)` check in the integer section
// below still preprocesses on compilers that do not provide `__has_builtin`;
// such compilers then take the portable (non-builtin) code path.
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

namespace bigquery_ml_utils {
namespace functions {

// Checked arithmetic on SQL values.
//
// Every function declared here follows the same convention: on success the
// result is stored in `*out` and the function returns true; on failure
// (overflow, division by zero, unsupported signature, ...) it returns false
// and reports the problem through `error` (via internal::UpdateError, which
// presumably returns false -- see util.h, or by assigning a returned status).
// Note that some integer paths write `*out` even when they report an error.
// Definitions and specializations for the supported types appear later in
// this header.

// Computes `in1 + in2`.
template <typename T>
inline bool Add(T in1, T in2, T *out, absl::Status* error);

// Computes `in1 - in2`. `OutType` may differ from `InType` (e.g. UINT64
// inputs with an INT64 result).
template <typename InType, typename OutType = InType>
inline bool Subtract(InType in1, InType in2, OutType* out, absl::Status* error);

// Computes `in1 * in2`.
template <typename T>
inline bool Multiply(T in1, T in2, T *out, absl::Status* error);

// Computes `in1 / in2`; a zero divisor is reported as an error.
template <typename T>
inline bool Divide(T in1, T in2, T *out, absl::Status* error);

// Computes `MOD(in1, in2)`; a zero divisor is reported as an error.
template <typename T>
inline bool Modulo(T in1, T in2, T *out, absl::Status* error);

// Computes `-in`. `OutType` defaults to `InType`.
template <typename InType, typename OutType = InType>
inline bool UnaryMinus(InType in, OutType *out, absl::Status* error);

// Division function for NUMERIC/BIGNUMERICs with integer semantics.
template <typename T> inline bool DivideToIntegralValue(T in1, T in2, T* out, absl::Status* error); // ----------------------- Internal parts ----------------------- // These are implementation details. Do not use outside of this file. namespace internal { template <typename T> inline bool CheckFloatOverflow(T in1, T in2, absl::string_view operator_symbol, T out, absl::Status* error) { if (ABSL_PREDICT_TRUE(std::isfinite(out))) { return true; } else if (!std::isfinite(in1) || !std::isfinite(in2)) { return true; } else { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, operator_symbol)); } } // Checks if range of representable values of type From lies entirely within // the range of representable values of type To. If this is true then From can // be casted to To without overflows. template <typename To, typename From> struct is_safe_to_cast { static_assert(std::numeric_limits<To>::is_integer && std::numeric_limits<From>::is_integer, "is_safe_to_cast can only be used with integer types"); static constexpr bool value = (std::numeric_limits<To>::digits >= std::numeric_limits<From>::digits) && (std::numeric_limits<To>::is_signed >= std::numeric_limits<From>::is_signed); }; template <typename To, typename From> inline To safe_cast(From in) { static_assert(is_safe_to_cast<To, From>::value, "This cast is not safe."); return static_cast<To>(in); } } // namespace internal // ----------------------- Floating point ----------------------- template <> inline bool Add(double in1, double in2, double *out, absl::Status* error) { *out = in1 + in2; return internal::CheckFloatOverflow(in1, in2, " + ", *out, error); } template <> inline bool Subtract(double in1, double in2, double *out, absl::Status* error) { *out = in1 - in2; return internal::CheckFloatOverflow(in1, in2, " - ", *out, error); } template <> inline bool Multiply(double in1, double in2, double *out, absl::Status* error) { *out = in1 * in2; return internal::CheckFloatOverflow(in1, in2, " * 
", *out, error); } template <> inline bool Divide(double in1, double in2, double *out, absl::Status* error) { if (ABSL_PREDICT_FALSE(in2 == 0)) { return internal::UpdateError(error, internal::DivisionByZeroMessage(in1, in2)); } *out = in1 / in2; if (ABSL_PREDICT_TRUE(std::isfinite(*out))) { return true; } else if (!std::isfinite(in1) || !std::isfinite(in2)) { return true; } else { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " / ")); } } template <> inline bool UnaryMinus(double in, double *out, absl::Status* error) { *out = -in; return true; } template <> inline bool Add(long double in1, long double in2, long double* out, absl::Status* error) { *out = in1 + in2; return internal::CheckFloatOverflow(in1, in2, " + ", *out, error); } template <> inline bool Subtract(long double in1, long double in2, long double* out, absl::Status* error) { *out = in1 - in2; return internal::CheckFloatOverflow(in1, in2, " - ", *out, error); } template <> inline bool Multiply(long double in1, long double in2, long double* out, absl::Status* error) { *out = in1 * in2; return internal::CheckFloatOverflow(in1, in2, " * ", *out, error); } template <> inline bool Divide(long double in1, long double in2, long double* out, absl::Status* error) { if (ABSL_PREDICT_FALSE(in2 == 0)) { return internal::UpdateError(error, internal::DivisionByZeroMessage(in1, in2)); } *out = in1 / in2; if (ABSL_PREDICT_TRUE(std::isfinite(*out))) { return true; } else if (!std::isfinite(in1) || !std::isfinite(in2)) { return true; } else { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " / ")); } } template <> inline bool UnaryMinus(long double in, long double* out, absl::Status* error) { *out = -in; return true; } template <> inline bool Add(float in1, float in2, float* out, absl::Status* error) { *out = in1 + in2; return internal::CheckFloatOverflow(in1, in2, " + ", *out, error); } template <> inline bool Subtract(float in1, float in2, float *out, 
absl::Status* error) { *out = in1 - in2; return internal::CheckFloatOverflow(in1, in2, " - ", *out, error); } template <> inline bool Multiply(float in1, float in2, float *out, absl::Status* error) { *out = in1 * in2; return internal::CheckFloatOverflow(in1, in2, " * ", *out, error); } template <> inline bool UnaryMinus(float in, float *out, absl::Status* error) { *out = -in; return true; } // ----------------------- Integer ----------------------- // LLVM and Clang have builtin functions that implement overflow checking most // efficiently (using overflow flag). We use them when they are available. // http://clang.llvm.org/docs/LanguageExtensions.html#builtin-functions // Even when this builtins are available we may still need more generic code // below when input types are not the same. #if __has_builtin(__builtin_uadd_overflow) static_assert(std::is_same<uint32_t, unsigned>::value, // NOLINT(runtime/int) "unsigned != uint32_t?"); static_assert(std::is_same<int32_t, int>::value, "int != int32_t?"); // 64-bit integers may be either 'long' or 'long long' depending on the system, // but we have to explicitly specify the underlying type in the name of the // built-in function. We store the result in an intermediate value to work // around that. 
template <> inline bool Add<uint64_t>(uint64_t in1, uint64_t in2, uint64_t* out, absl::Status* error) { unsigned long long result; // NOLINT(runtime/int) bool has_overflow = __builtin_uaddll_overflow(in1, in2, &result); *out = static_cast<uint64_t>(result); if (ABSL_PREDICT_FALSE(has_overflow)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " + ")); } else { return true; } } template <> inline bool Multiply<uint64_t>(uint64_t in1, uint64_t in2, uint64_t* out, absl::Status* error) { unsigned long long result; // NOLINT(runtime/int) bool has_overflow = __builtin_umulll_overflow(in1, in2, &result); *out = static_cast<uint64_t>(result); if (ABSL_PREDICT_FALSE(has_overflow)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " * ")); } else { return true; } } template <> inline bool Add<int32_t>(int32_t in1, int32_t in2, int32_t* out, absl::Status* error) { if (ABSL_PREDICT_FALSE(__builtin_sadd_overflow(in1, in2, out))) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " + ")); } else { return true; } } template <> inline bool Add<int64_t>(int64_t in1, int64_t in2, int64_t* out, absl::Status* error) { long long result; // NOLINT(runtime/int) bool has_overflow = __builtin_saddll_overflow(in1, in2, &result); *out = static_cast<int64_t>(result); if (ABSL_PREDICT_FALSE(has_overflow)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " + ")); } else { return true; } } template <> inline bool Subtract<int64_t>(int64_t in1, int64_t in2, int64_t* out, absl::Status* error) { long long result; // NOLINT(runtime/int) bool has_overflow = __builtin_ssubll_overflow(in1, in2, &result); *out = static_cast<int64_t>(result); if (ABSL_PREDICT_FALSE(has_overflow)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " - ")); } else { return true; } } template <> inline bool Multiply<int32_t>(int32_t in1, int32_t in2, int32_t* out, 
absl::Status* error) { if (ABSL_PREDICT_FALSE(__builtin_smul_overflow(in1, in2, out))) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " * ")); } else { return true; } } template <> inline bool Multiply<int64_t>(int64_t in1, int64_t in2, int64_t* out, absl::Status* error) { long long result; // NOLINT(runtime/int) bool has_overflow = __builtin_smulll_overflow(in1, in2, &result); *out = static_cast<int64_t>(result); if (ABSL_PREDICT_FALSE(has_overflow)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " * ")); } else { return true; } } #else namespace arithmetics_internal { template <typename T> inline bool CheckSaturatedOverflow(T in1, T in2, absl::string_view operator_symbol, Saturated<T> result, T* out, absl::Status* error) { *out = result.Value(); if (ABSL_PREDICT_TRUE(result.IsValid())) { return true; } else { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, operator_symbol)); } } } // namespace arithmetics_internal template <typename T> bool Add(T in1, T in2, T* out, absl::Status* error) { arithmetics_internal::Saturated<T> result(internal::safe_cast<T>(in1)); result.Add(internal::safe_cast<T>(in2)); *out = result.Value(); return arithmetics_internal::CheckSaturatedOverflow(in1, in2, " + ", result, out, error); } template <> inline bool Subtract<int64_t>(int64_t in1, int64_t in2, int64_t* out, absl::Status* error) { arithmetics_internal::Saturated<int64_t> result( internal::safe_cast<int64_t>(in1)); result.Sub(internal::safe_cast<int64_t>(in2)); *out = result.Value(); return arithmetics_internal::CheckSaturatedOverflow(in1, in2, " - ", result, out, error); } template <typename T> inline bool Multiply(T in1, T in2, T *out, absl::Status* error) { arithmetics_internal::Saturated<T> result(internal::safe_cast<T>(in1)); result.Mul(internal::safe_cast<T>(in2)); *out = result.Value(); return arithmetics_internal::CheckSaturatedOverflow(in1, in2, " * ", result, out, error); 
} #endif template <typename T> inline bool Modulo(T in1, T in2, T *out, absl::Status* error) { static_assert( std::is_same<uint64_t, T>::value || std::is_same<int64_t, T>::value, "Modulo only supports 64 bit integer"); if (ABSL_PREDICT_FALSE(in2 == 0)) { return internal::UpdateError( error, absl::StrCat("division by zero: MOD(", in1, ", ", in2, ")")); } if constexpr (std::is_same_v<int64_t, T>) { if (ABSL_PREDICT_FALSE(in2 == -1)) { // Workaround for -9223372035808 % -1 triggering floating point exception. *out = 0; return true; } } *out = in1 % in2; return true; } static_assert(std::numeric_limits<int32_t>::min() + std::numeric_limits<int32_t>::max() == -1, "int32 is not a two's complement type?"); static_assert(std::numeric_limits<int64_t>::min() + std::numeric_limits<int64_t>::max() == -1, "int64 is not a two's complement type?"); template <> inline bool UnaryMinus<int32_t, int32_t>(int32_t in, int32_t* out, absl::Status* error) { if (in == std::numeric_limits<int32_t>::min()) { return internal::UpdateError(error, internal::UnaryOverflowMessage(in, "-")); } *out = -in; return true; } template <> inline bool UnaryMinus<int64_t, int64_t>(int64_t in, int64_t* out, absl::Status* error) { if (in == std::numeric_limits<int64_t>::min()) { return internal::UpdateError(error, internal::UnaryOverflowMessage(in, "-")); } *out = -in; return true; } template <> inline bool Subtract<uint64_t, int64_t>(uint64_t in1, uint64_t in2, int64_t* out, absl::Status* error) { if (in1 >= in2) { uint64_t tmp = in1 - in2; if (!Convert<uint64_t, int64_t>(tmp, out, nullptr)) { return internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " - ")); } return true; } uint64_t tmp = in2 - in1; if (ABSL_PREDICT_FALSE( tmp == 1ull + static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))) { *out = std::numeric_limits<int64_t>::min(); return true; } if (!Convert<uint64_t, int64_t>(tmp, out, nullptr) || !UnaryMinus<int64_t, int64_t>(*out, out, error)) { return 
internal::UpdateError( error, internal::BinaryOverflowMessage(in1, in2, " - ")); } return true; } template <> inline bool Subtract<uint64_t, uint64_t>(uint64_t in1, uint64_t in2, uint64_t* out, absl::Status* error) { return internal::UpdateError(error, "invalid UINT64 subtraction signature"); } template <> inline bool Divide(uint64_t in1, uint64_t in2, uint64_t* out, absl::Status* error) { if (ABSL_PREDICT_FALSE(in2 == 0)) { return internal::UpdateError(error, internal::DivisionByZeroMessage(in1, in2)); } *out = in1 / in2; return true; } template <> inline bool Divide(int64_t in1, int64_t in2, int64_t* out, absl::Status* error) { if (ABSL_PREDICT_FALSE(in2 == -1)) { return UnaryMinus(in1, out, error); } if (ABSL_PREDICT_FALSE(in2 == 0)) { return internal::UpdateError(error, internal::DivisionByZeroMessage(in1, in2)); } *out = in1 / in2; return true; } // ----------------------- Numeric ----------------------- template <> inline bool Add(NumericValue in1, NumericValue in2, NumericValue* out, absl::Status* error) { absl::StatusOr<NumericValue> numeric_status = in1.Add(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } template <> inline bool Subtract(NumericValue in1, NumericValue in2, NumericValue* out, absl::Status* error) { absl::StatusOr<NumericValue> numeric_status = in1.Subtract(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } template <> inline bool Multiply(NumericValue in1, NumericValue in2, NumericValue* out, absl::Status* error) { absl::StatusOr<NumericValue> numeric_status = in1.Multiply(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } template <> inline bool Divide(NumericValue in1, 
NumericValue in2, NumericValue* out, absl::Status* error) { absl::StatusOr<NumericValue> numeric_status = in1.Divide(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } template <> inline bool UnaryMinus(NumericValue in, NumericValue* out, absl::Status* error) { *out = in.Negate(); return true; } template <> inline bool Modulo(NumericValue in1, NumericValue in2, NumericValue *out, absl::Status* error) { absl::StatusOr<NumericValue> numeric_status = in1.Mod(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } // ----------------------- BIGNUMERIC ----------------------- template <> inline bool Add(BigNumericValue in1, BigNumericValue in2, BigNumericValue* out, absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in1.Add(in2); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = *bignumeric_status; return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } template <> inline bool Subtract(BigNumericValue in1, BigNumericValue in2, BigNumericValue* out, absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in1.Subtract(in2); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = *bignumeric_status; return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } template <> inline bool Multiply(BigNumericValue in1, BigNumericValue in2, BigNumericValue* out, absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in1.Multiply(in2); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = *bignumeric_status; return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } template <> inline bool Divide(BigNumericValue in1, BigNumericValue in2, BigNumericValue* out, 
absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in1.Divide(in2); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = *bignumeric_status; return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } template <> inline bool UnaryMinus(BigNumericValue in, BigNumericValue* out, absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in.Negate(); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = *bignumeric_status; return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } template <> inline bool Modulo(BigNumericValue in1, BigNumericValue in2, BigNumericValue* out, absl::Status* error) { absl::StatusOr<BigNumericValue> bignumeric_status = in1.Mod(in2); if (ABSL_PREDICT_TRUE(bignumeric_status.ok())) { *out = bignumeric_status.value(); return true; } if (error != nullptr) { *error = bignumeric_status.status(); } return false; } // ----------------------- NUMERIC/BIGNUMERIC ----------------------- template <typename T> inline bool DivideToIntegralValue(T in1, T in2, T* out, absl::Status* error) { absl::StatusOr<T> numeric_status = in1.DivideToIntegralValue(in2); if (ABSL_PREDICT_TRUE(numeric_status.ok())) { *out = numeric_status.value(); return true; } if (error != nullptr) { *error = numeric_status.status(); } return false; } } // namespace functions } // namespace bigquery_ml_utils #endif // THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_PUBLIC_FUNCTIONS_ARITHMETICS_H_