sql_utils/public/functions/convert_internal.h (65 lines of code) (raw):
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "sql_utils/base/logging.h"
namespace bigquery_ml_utils {
// Do not use any methods from the convert_internal namespace.
namespace convert_internal {
// Forward declaration; the definition appears further down in this file.
//
// Returns true if the value is in the representable range of the result type:
// MIN <= value <= MAX, where MIN and MAX are
// std::numeric_limits<ResultType>::min() and ::max().
//
// Returns false when value is NaN, and when value is negative while
// ResultType is unsigned. FloatType must be a floating point type and
// ResultType an integral type; the definition enforces both with
// static_assert.
template <typename FloatType, typename ResultType>
bool InRangeNoTruncate(FloatType value);
// Reports whether value, once truncated toward zero, does not exceed the MAX
// value of IntType.
//
// A direct `value <= MAX` comparison is unreliable when MAX has no exact
// FloatType representation: MAX gets promoted to a FloatType that is larger
// than anything IntType can hold. The check is therefore done on the binary
// exponent of value instead.
//
// Zero and negative inputs (including negative infinity) yield true; NaN and
// positive infinity yield false.
template <typename FloatType, typename IntType>
bool SmallerThanOrEqualToIntMax(FloatType value) {
  static_assert(std::numeric_limits<FloatType>::radix == 2,
                "return type size must be based on 2");

  // Anything at or below zero is trivially no greater than MAX. NaN fails
  // this comparison and falls through to the finiteness check below.
  if (value <= 0) {
    return true;
  }
  // Only NaN and +infinity can reach this point non-finite.
  if (!std::isfinite(value)) {
    return false;
  }

  // frexp decomposes value as f * 2^binary_exponent with f in [0.5, 1.0), so
  // value is strictly below 2^binary_exponent.
  int binary_exponent = 0;
  std::frexp(value, &binary_exponent);

  // IntType has `digits` non-sign bits, i.e. MAX == 2^digits - 1. A value
  // strictly below 2^digits truncates to something IntType can represent.
  const int non_sign_bits = std::numeric_limits<IntType>::digits;
  return binary_exponent <= non_sign_bits;
}
// Returns true if the max value of IntType can be represented precisely in
// FloatType, i.e. FloatType's significand has at least as many bits as
// IntType has non-sign value bits.
//
// The significand width comes from std::numeric_limits<FloatType>::digits
// rather than being inferred from sizeof(FloatType). Object size includes
// padding (for example, the x87 extended format has a 64-bit significand but
// is commonly stored in 12 or 16 bytes), so a size-based lookup can report
// the wrong precision or reject a perfectly usable type. Using
// numeric_limits also makes the answer a compile-time constant and removes
// the need for a runtime failure path.
template <typename FloatType, typename IntType>
constexpr bool CanRepresentMaxPrecisely() {
  static_assert(std::numeric_limits<FloatType>::is_specialized &&
                    !std::numeric_limits<FloatType>::is_integer,
                "FloatType is not supported");
  // `digits` counts radix digits; comparing it against the bit count of
  // IntType is only meaningful for binary floating point formats.
  static_assert(std::numeric_limits<FloatType>::radix == 2,
                "FloatType must use a binary significand");
  return std::numeric_limits<FloatType>::digits >=
         std::numeric_limits<IntType>::digits;
}
// Reports whether value lies inside the closed range [MIN, MAX] of
// ResultType, where MIN and MAX come from std::numeric_limits<ResultType>.
//
// NaN is never in range, and a negative value is never in range for an
// unsigned ResultType.
template <typename FloatType, typename ResultType>
bool InRangeNoTruncate(FloatType value) {
  static_assert(std::is_floating_point<FloatType>::value,
                "value must have floating point type");
  static_assert(std::is_integral<ResultType>::value,
                "return value must have integral type");
  static_assert(sizeof(ResultType) <= 16,
                "ResultType is no larger than 128 bits");
  static_assert(sizeof(FloatType) >= 4,
                "FloatType is no smaller than IEEE754 single precision");
  using ResultLimits = std::numeric_limits<ResultType>;

  // NaN compares false against every bound, so reject it explicitly.
  if (std::isnan(value)) {
    return false;
  }
  // An unsigned result type cannot hold a negative value.
  if (value < 0 && !std::is_signed<ResultType>::value) {
    return false;
  }

  // MIN is either 0 or -2^N, both of which convert to FloatType exactly, so
  // the lower bound can always be compared directly.
  const FloatType min_as_float = ResultLimits::min();
  const bool at_least_min = value >= min_as_float;

  if (!CanRepresentMaxPrecisely<FloatType, ResultType>()) {
    // MAX (2^N - 1) is not exactly representable in FloatType, so converting
    // it would produce a larger value than ResultType can hold. Use the
    // exponent-based upper bound check instead of a direct comparison.
    return at_least_min &&
           SmallerThanOrEqualToIntMax<FloatType, ResultType>(value);
  }

  // MAX converts exactly, so a plain comparison is safe.
  const FloatType max_as_float = ResultLimits::max();
  return at_least_min && value <= max_as_float;
}
} // namespace convert_internal
} // namespace bigquery_ml_utils