sql_utils/base/mathutil.h (72 lines of code) (raw):
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_MATHUTIL_H_
#define THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_MATHUTIL_H_
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <vector>

#include "absl/base/attributes.h"
#include "sql_utils/base/logging.h"
namespace bigquery_ml_utils_base {
class MathUtil {
 public:
  // Returns floor(numerator / denominator), i.e. the exact quotient rounded
  // toward negative infinity, computed purely in integer arithmetic.
  //
  // Preconditions: `denominator` must be nonzero, and for signed types
  // `numerator == lowest()` with `denominator == -1` is not supported (the
  // quotient is not representable).
  template <typename IntegralType>
  static IntegralType FloorOfRatio(IntegralType numerator,
                                   IntegralType denominator) {
    return CeilOrFloorOfRatio<IntegralType, false>(numerator, denominator);
  }

  // Returns ceil(numerator / denominator), i.e. the exact quotient rounded
  // toward positive infinity, computed purely in integer arithmetic.
  //
  // Same preconditions as FloorOfRatio().
  template <typename IntegralType>
  static IntegralType CeilOfRatio(IntegralType numerator,
                                  IntegralType denominator) {
    return CeilOrFloorOfRatio<IntegralType, true>(numerator, denominator);
  }

  // Shared implementation of CeilOfRatio() and FloorOfRatio(): rounds the
  // quotient up when `ceil` is true and down when it is false. Only integral
  // types are accepted (enforced by a static_assert). Defined after the class.
  template <typename IntegralType, bool ceil>
  static IntegralType CeilOrFloorOfRatio(IntegralType numerator,
                                         IntegralType denominator);

  // Returns the nonnegative remainder when one integer is divided by another,
  // i.e. a value in [0, b). The modulus `b` must be positive. Use integral
  // types only (no float or double).
  template <class T>
  static T NonnegativeMod(T a, T b) {
    static_assert(std::is_integral<T>::value, "Integral types only.");
    SQL_DCHECK_GT(b, 0);
    // As of C++11 (per [expr.mul]/4), a%b is in (-b,0] for a<0, b>0, so a
    // negative remainder is shifted up by exactly one modulus.
    T c = a % b;
    return c + (c < 0) * b;
  }

  // Returns the minimum integer value which is a multiple of rounding_value,
  // and greater than or equal to input_value.
  // The input_value must be greater than or equal to zero, and the
  // rounding_value must be greater than zero. The caller must also ensure the
  // rounded-up result is representable in IntType; otherwise the final
  // addition overflows.
  template <typename IntType>
  static IntType RoundUpTo(IntType input_value, IntType rounding_value) {
    static_assert(std::numeric_limits<IntType>::is_integer,
                  "RoundUpTo() operation type is not integer");
    SQL_DCHECK_GE(input_value, 0);
    SQL_DCHECK_GT(rounding_value, 0);
    const IntType remainder = input_value % rounding_value;
    return (remainder == 0) ? input_value
                            : (input_value - remainder + rounding_value);
  }

  // Decomposes `value` to the form `mantissa * pow(2, exponent)`. Similar to
  // `std::frexp`, but returns `mantissa` as an integer without normalization.
  //
  // The returned `mantissa` might be a power of 2; this method does not shift
  // the trailing 0 bits out.
  //
  // Special values (documented for the `double` overload, whose mantissa is an
  // `int64_t`):
  //   * If `value` is inf, then `mantissa = kint64max` and `exponent = intmax`.
  //   * If `value` is -inf, then `mantissa = -kint64max` and
  //     `exponent = intmax`.
  //   * If `value` is NaN, then `mantissa = 0` and `exponent = intmax`.
  //   * If `value` is 0, then `mantissa = 0` and `exponent < 0`.
  // NOTE(review): the `float` overload returns an `int32_t` mantissa, which
  // cannot hold kint64max; its inf/-inf mantissa is presumably the int32_t
  // maximum (negated for -inf) — confirm against the implementation.
  //
  // For all cases, `value` is equivalent to
  // `static_cast<double>(mantissa) * std::ldexp(1.0, exponent)`, though the
  // bits might differ (e.g., `-0.0` vs `0.0`, signaling NaN vs quiet NaN).
  //
  // For all cases except NaN,
  // `value = std::ldexp(static_cast<double>(mantissa), exponent)`.
  struct FloatParts {
    int32_t mantissa;
    int exponent;
  };
  static FloatParts Decompose(float value);

  struct DoubleParts {
    int64_t mantissa;
    int exponent;
  };
  static DoubleParts Decompose(double value);

 private:
  // Wraps `x` to the periodic range `[low, high)`.
  static double Wrap(double x, double low, double high);
};
// ---- CeilOrFloorOfRatio ----
// This is a branching-free, cast-to-double-free implementation.
//
// Casting to double is in general incorrect because of loss of precision
// when casting an int64_t into a double.
//
// There's a bunch of 'recipes' to compute an integer ceil (or floor) on the
// web, and most of them are incorrect.
//
// Approach: built-in integer division truncates toward zero. The truncated
// quotient therefore already equals the ceil of a negative exact quotient and
// the floor of a positive one, so a correction of one is needed only when the
// remainder is nonzero and the quotient's sign points away from the requested
// rounding direction:
//   * ceil:  add 1 when the operands have the same sign (positive quotient);
//   * floor: subtract 1 when the operands differ in sign (negative quotient).
template <typename IntegralType, bool ceil>
IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
                                          IntegralType denominator) {
  static_assert(std::numeric_limits<IntegralType>::is_integer,
                "CeilOrFloorOfRatio is only defined for integral types");
  SQL_DCHECK_NE(0, denominator) << "Division by zero is not supported.";
  // For signed types, lowest() / -1 is not representable.
  SQL_DCHECK(!std::numeric_limits<IntegralType>::is_signed ||
             numerator != std::numeric_limits<IntegralType>::lowest() ||
             denominator != -1)
      << "Dividing " << numerator << " by -1 is not supported: it would SIGFPE";
  const IntegralType rounded_toward_zero = numerator / denominator;
  const bool needs_round = (numerator % denominator) != 0;
  // It is important to use >= here, even for the denominator, to ensure that
  // this value is a compile-time constant for unsigned types.
  const bool same_sign = (numerator >= 0) == (denominator >= 0);
  if (ceil) {  // Compile-time condition: not an actual branching
    return rounded_toward_zero +
           static_cast<IntegralType>(same_sign && needs_round);
  } else {
    return rounded_toward_zero -
           static_cast<IntegralType>(!same_sign && needs_round);
  }
}
} // namespace bigquery_ml_utils_base
#endif // THIRD_PARTY_PY_BIGQUERY_ML_UTILS_SQL_UTILS_BASE_MATHUTIL_H_