in be/src/vec/functions/function_binary_arithmetic.h [227:705]
throw Exception(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR, \
"Arithmetic overflow: {} {} {} = {}, result type: {}", left_value, op_name, \
right_value, result_value, result_type_name)
/// Binary arithmetic on decimal operands. The scale of the result depends on the operation:
///   + / - : the operand whose scale factor is not 1 is rescaled, so ScaleR is one of
///           (Scale1, Scale2);
///   *     : no operand is rescaled, so ScaleR = Scale1 + Scale2 (the product is divided by
///           `scale_diff_multiplier` when the result type keeps fewer fractional digits);
///   /     : scaled by the first argument, ScaleR = Scale1
///           (scale_a = DecimalType<B>::get_scale()).
///
/// Template flags:
///   is_to_null_type - selects the divide/modulo/pmod variants, which report per-row NULL
///                     results through a null map and return a ColumnNullable;
///   check_overflow  - when set, a native overflow or a result whose magnitude exceeds
///                     `max_result_number` throws via THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION.
template <typename LeftDataType, typename RightDataType, typename ResultDataType,
          template <typename, typename> typename Operation, typename Name, typename ResultType,
          bool is_to_null_type, bool check_overflow>
struct DecimalBinaryOperation {
    using A = typename LeftDataType::FieldType;
    using B = typename RightDataType::FieldType;
    using OpTraits = OperationTraits<Operation, A, B>;
    /// The operation is evaluated on the native integer representation of ResultType.
    using NativeResultType = typename NativeType<ResultType>::Type;
    using Op = Operation<NativeResultType, NativeResultType>;

    using Traits = NumberTraits::BinaryOperatorTraits<A, B>;
    using ArrayC = typename ColumnDecimal<ResultType>::Container;

private:
    /// Sign of `x`: 1 if positive, -1 if negative, 0 otherwise.
    template <typename T>
    static int8_t sgn(const T& x) {
        return (x > 0) ? 1 : ((x < 0) ? -1 : 0);
    }

    /// Element-wise `a[i] <op> b[i]` for +, - and * over `size` rows, written to `c`.
    ///
    /// A `scale_diff_multiplier` greater than 1 means a product must be divided by it to fit
    /// the result scale. With `check_overflow` the division happens inside `apply`, so the
    /// overflow check sees the final value. Without it, `apply` produces the raw product and
    /// the division is applied in a second pass below.
    static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
                              const typename Traits::ArrayB::value_type* __restrict b,
                              typename ArrayC::value_type* c, const LeftDataType& type_left,
                              const RightDataType& type_right, const ResultDataType& type_result,
                              size_t size, const ResultType& max_result_number,
                              const ResultType& scale_diff_multiplier) {
        static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
        if constexpr (OpTraits::is_multiply && IsDecimalV2<A> && IsDecimalV2<B> &&
                      IsDecimalV2<ResultType>) {
            // DecimalV2 * DecimalV2 -> DecimalV2 is delegated to the operation's batch kernel.
            Op::template vector_vector<check_overflow>(a, b, c, size);
        } else {
            bool need_adjust_scale = scale_diff_multiplier.value > 1;
            // Lift the runtime flag into a compile-time constant so the per-row loop is
            // specialised for apply<true> / apply<false>.
            std::visit(
                    [&](auto need_adjust_scale) {
                        for (size_t i = 0; i < size; i++) {
                            c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
                                    a[i], b[i], type_left, type_right, type_result,
                                    max_result_number, scale_diff_multiplier));
                        }
                    },
                    make_bool_variant(need_adjust_scale && check_overflow));

            if constexpr (OpTraits::is_multiply && !check_overflow) {
                if (need_adjust_scale) {
                    // Divide the raw products by the scale difference in place. The sign is
                    // computed per row, so no temporary buffer is needed.
                    // NOTE(review): x -> (x - sgn(x)) / m + sgn(x) (C++ truncating division)
                    // rounds the magnitude up, away from zero, whenever there is a remainder.
                    // apply<true>, used by vector_constant/constant_vector, rounds half away
                    // from zero instead. The two paths can therefore differ by one unit in the
                    // last place; confirm which rounding is intended.
                    for (size_t i = 0; i < size; i++) {
                        const int8_t sign = sgn(c[i].value);
                        c[i].value = (c[i].value - sign) / scale_diff_multiplier.value + sign;
                    }
                }
            }
        }
    }

    /// Element-wise division or modulo over `size` rows, written to `c`.
    ///
    /// `null_map[i]` is passed by reference to the per-row `apply`, so the operation can mark
    /// row i as NULL (presumably e.g. for a zero divisor; confirm in the Op implementation).
    static void vector_vector(const typename Traits::ArrayA::value_type* __restrict a,
                              const typename Traits::ArrayB::value_type* __restrict b,
                              typename ArrayC::value_type* c, NullMap& null_map, size_t size,
                              const ResultType& max_result_number) {
        static_assert(OpTraits::is_division || OpTraits::is_mod);
        if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(
                        apply(a[i], b[i], null_map[i], max_result_number));
            }
        } else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
            // Division with DecimalV3 operands works on their raw integer values.
            for (size_t i = 0; i < size; ++i) {
                if constexpr (IsDecimalNumber<B> && IsDecimalNumber<A>) {
                    c[i] = typename ArrayC::value_type(
                            apply(a[i].value, b[i].value, null_map[i], max_result_number));
                } else if constexpr (IsDecimalNumber<A>) {
                    c[i] = typename ArrayC::value_type(
                            apply(a[i].value, b[i], null_map[i], max_result_number));
                } else {
                    c[i] = typename ArrayC::value_type(
                            apply(a[i], b[i].value, null_map[i], max_result_number));
                }
            }
        } else {
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(
                        apply(a[i], b[i], null_map[i], max_result_number));
            }
        }
    }

    /// `a[i] <op> b` for a constant right operand (+, - and *; division is rejected at compile
    /// time). Products are divided by `scale_diff_multiplier` inside `apply` when it exceeds 1.
    static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
                                typename ArrayC::value_type* c, const LeftDataType& type_left,
                                const RightDataType& type_right, const ResultDataType& type_result,
                                size_t size, const ResultType& max_result_number,
                                const ResultType& scale_diff_multiplier) {
        static_assert(!OpTraits::is_division);

        bool need_adjust_scale = scale_diff_multiplier.value > 1;
        std::visit(
                [&](auto need_adjust_scale) {
                    for (size_t i = 0; i < size; ++i) {
                        c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
                                a[i], b, type_left, type_right, type_result, max_result_number,
                                scale_diff_multiplier));
                    }
                },
                make_bool_variant(need_adjust_scale));
    }

    /// `a[i] / b` or `a[i] % b` for a constant right operand; `null_map[i]` is handed to the
    /// operation so it can flag the row as NULL.
    static void vector_constant(const typename Traits::ArrayA::value_type* __restrict a, B b,
                                typename ArrayC::value_type* c, NullMap& null_map, size_t size,
                                const ResultType& max_result_number) {
        static_assert(OpTraits::is_division || OpTraits::is_mod);
        if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
            // Division by a DecimalV3 constant uses its raw integer value.
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(
                        apply(a[i], b.value, null_map[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i], max_result_number));
            }
        }
    }

    /// `a <op> b[i]` for a constant left operand (non-null variant). Products are divided by
    /// `scale_diff_multiplier` inside `apply` when it exceeds 1.
    static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
                                typename ArrayC::value_type* c, const LeftDataType& type_left,
                                const RightDataType& type_right, const ResultDataType& type_result,
                                size_t size, const ResultType& max_result_number,
                                const ResultType& scale_diff_multiplier) {
        bool need_adjust_scale = scale_diff_multiplier.value > 1;
        std::visit(
                [&](auto need_adjust_scale) {
                    for (size_t i = 0; i < size; ++i) {
                        c[i] = typename ArrayC::value_type(apply<need_adjust_scale>(
                                a, b[i], type_left, type_right, type_result, max_result_number,
                                scale_diff_multiplier));
                    }
                },
                make_bool_variant(need_adjust_scale));
    }

    /// `a / b[i]` or `a % b[i]` for a constant left operand; `null_map[i]` is handed to the
    /// operation so it can flag the row as NULL.
    static void constant_vector(A a, const typename Traits::ArrayB::value_type* __restrict b,
                                typename ArrayC::value_type* c, NullMap& null_map, size_t size,
                                const ResultType& max_result_number) {
        static_assert(OpTraits::is_division || OpTraits::is_mod);
        if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
            // Division by a DecimalV3 column uses the raw integer values.
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(
                        apply(a, b[i].value, null_map[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < size; ++i) {
                c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i], max_result_number));
            }
        }
    }

    /// `a <op> b` for two constants (non-null variant). Always uses apply<true>. When
    /// `scale_diff_multiplier` is 1 the division by it leaves the value unchanged.
    static ResultType constant_constant(A a, B b, const LeftDataType& type_left,
                                        const RightDataType& type_right,
                                        const ResultDataType& type_result,
                                        const ResultType& max_result_number,
                                        const ResultType& scale_diff_multiplier) {
        return ResultType(apply<true>(a, b, type_left, type_right, type_result, max_result_number,
                                      scale_diff_multiplier));
    }

    /// `a / b` or `a % b` for two constants; `is_null` is handed to the operation so it can
    /// flag the result as NULL.
    static ResultType constant_constant(A a, B b, UInt8& is_null,
                                        const ResultType& max_result_number) {
        static_assert(OpTraits::is_division || OpTraits::is_mod);
        if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
            if constexpr (IsDecimalNumber<A>) {
                return ResultType(apply(a.value, b.value, is_null, max_result_number));
            } else {
                return ResultType(apply(a, b.value, is_null, max_result_number));
            }
        } else {
            return ResultType(apply(a, b, is_null, max_result_number));
        }
    }

public:
    /// Evaluates `a <op> b` for two constant operands into a one-row result column. For the
    /// null-reporting division/modulo variants, the column is wrapped in a ColumnNullable.
    ///
    /// @throws doris::Exception(INTERNAL_ERROR) if `check_overflow` is set for an operation that
    ///         is neither +, - nor * and the null-reporting variant is not selected.
    static ColumnPtr adapt_decimal_constant_constant(A a, B b, const LeftDataType& type_left,
                                                     const RightDataType& type_right,
                                                     const ResultType& max_result_number,
                                                     const ResultType& scale_diff_multiplier,
                                                     DataTypePtr res_data_type) {
        // Bind by reference: no copy of the data type, and a single cast serves both the
        // column scale and the per-row operation.
        const auto& type_result =
                assert_cast<const DataTypeDecimal<ResultType>&, TypeCheckOnRelease::DISABLE>(
                        *res_data_type);
        auto column_result = ColumnDecimal<ResultType>::create(1, type_result.get_scale());

        if constexpr (check_overflow && !is_to_null_type &&
                      (!OpTraits::is_multiply && !OpTraits::is_plus_minus)) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "adapt_decimal_constant_constant Invalid function type!");
        } else if constexpr (is_to_null_type) {
            auto null_map = ColumnUInt8::create(1, 0);
            column_result->get_element(0) =
                    constant_constant(a, b, null_map->get_element(0), max_result_number);
            return ColumnNullable::create(std::move(column_result), std::move(null_map));
        } else {
            column_result->get_element(0) =
                    constant_constant(a, b, type_left, type_right, type_result, max_result_number,
                                      scale_diff_multiplier);
            return column_result;
        }
    }

    /// Evaluates `column_left[i] <op> b` for every row of `column_left`. For the null-reporting
    /// division/modulo variants, the result is wrapped in a ColumnNullable.
    ///
    /// @throws doris::Exception(INTERNAL_ERROR) on the same invalid flag combination as
    ///         adapt_decimal_constant_constant.
    static ColumnPtr adapt_decimal_vector_constant(ColumnPtr column_left, B b,
                                                   const LeftDataType& type_left,
                                                   const RightDataType& type_right,
                                                   const ResultType& max_result_number,
                                                   const ResultType& scale_diff_multiplier,
                                                   DataTypePtr res_data_type) {
        const auto& type_result =
                assert_cast<const DataTypeDecimal<ResultType>&, TypeCheckOnRelease::DISABLE>(
                        *res_data_type);
        auto column_left_ptr =
                check_and_get_column<typename Traits::ColumnVectorA>(column_left.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), type_result.get_scale());
        DCHECK(column_left_ptr != nullptr);

        if constexpr (check_overflow && !is_to_null_type &&
                      (!OpTraits::is_multiply && !OpTraits::is_plus_minus)) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "adapt_decimal_vector_constant Invalid function type!");
        } else if constexpr (is_to_null_type) {
            auto null_map = ColumnUInt8::create(column_left->size(), 0);
            vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
                            null_map->get_data(), column_left->size(), max_result_number);
            return ColumnNullable::create(std::move(column_result), std::move(null_map));
        } else {
            vector_constant(column_left_ptr->get_data().data(), b, column_result->get_data().data(),
                            type_left, type_right, type_result, column_left->size(),
                            max_result_number, scale_diff_multiplier);
            return column_result;
        }
    }

    /// Evaluates `a <op> column_right[i]` for every row of `column_right`. For the
    /// null-reporting division/modulo variants, the result is wrapped in a ColumnNullable.
    ///
    /// @throws doris::Exception(INTERNAL_ERROR) on the same invalid flag combination as
    ///         adapt_decimal_constant_constant.
    static ColumnPtr adapt_decimal_constant_vector(A a, ColumnPtr column_right,
                                                   const LeftDataType& type_left,
                                                   const RightDataType& type_right,
                                                   const ResultType& max_result_number,
                                                   const ResultType& scale_diff_multiplier,
                                                   DataTypePtr res_data_type) {
        const auto& type_result =
                assert_cast<const DataTypeDecimal<ResultType>&, TypeCheckOnRelease::DISABLE>(
                        *res_data_type);
        auto column_right_ptr =
                check_and_get_column<typename Traits::ColumnVectorB>(column_right.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_right->size(), type_result.get_scale());
        DCHECK(column_right_ptr != nullptr);

        if constexpr (check_overflow && !is_to_null_type &&
                      (!OpTraits::is_multiply && !OpTraits::is_plus_minus)) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "adapt_decimal_constant_vector Invalid function type!");
        } else if constexpr (is_to_null_type) {
            auto null_map = ColumnUInt8::create(column_right->size(), 0);
            constant_vector(a, column_right_ptr->get_data().data(),
                            column_result->get_data().data(), null_map->get_data(),
                            column_right->size(), max_result_number);
            return ColumnNullable::create(std::move(column_result), std::move(null_map));
        } else {
            constant_vector(a, column_right_ptr->get_data().data(),
                            column_result->get_data().data(), type_left, type_right, type_result,
                            column_right->size(), max_result_number, scale_diff_multiplier);
            return column_result;
        }
    }

    /// Evaluates `column_left[i] <op> column_right[i]` row by row. The result has
    /// `column_left->size()` rows. For the null-reporting division/modulo/pmod variants, it is
    /// wrapped in a ColumnNullable.
    ///
    /// @throws doris::Exception(INTERNAL_ERROR) on the same invalid flag combination as
    ///         adapt_decimal_constant_constant.
    static ColumnPtr adapt_decimal_vector_vector(ColumnPtr column_left, ColumnPtr column_right,
                                                 const LeftDataType& type_left,
                                                 const RightDataType& type_right,
                                                 const ResultType& max_result_number,
                                                 const ResultType& scale_diff_multiplier,
                                                 DataTypePtr res_data_type) {
        auto column_left_ptr =
                check_and_get_column<typename Traits::ColumnVectorA>(column_left.get());
        auto column_right_ptr =
                check_and_get_column<typename Traits::ColumnVectorB>(column_right.get());

        const auto& type_result = assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type);
        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), type_result.get_scale());
        DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);

        if constexpr (check_overflow && !is_to_null_type &&
                      (!OpTraits::is_multiply && !OpTraits::is_plus_minus)) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "adapt_decimal_vector_vector Invalid function type!");
        } else if constexpr (is_to_null_type) {
            // function divide, modulo and pmod
            auto null_map = ColumnUInt8::create(column_result->size(), 0);
            vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
                          column_result->get_data().data(), null_map->get_data(),
                          column_left->size(), max_result_number);
            return ColumnNullable::create(std::move(column_result), std::move(null_map));
        } else {
            vector_vector(column_left_ptr->get_data().data(), column_right_ptr->get_data().data(),
                          column_result->get_data().data(), type_left, type_right, type_result,
                          column_left->size(), max_result_number, scale_diff_multiplier);
            return column_result;
        }
    }

private:
    /// Per-row `a <op> b` for +, - and *. Operands are implicitly converted to
    /// NativeResultType at the call site.
    ///
    /// need_adjust_scale: for multiplication, divide the product by `scale_diff_multiplier`,
    /// rounding half away from zero.
    /// check_overflow: a native overflow, or a final magnitude above `max_result_number`, throws
    /// via THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION. The message formats the operands with
    /// the left/right types and the result with a Decimal256 type of the result scale.
    template <bool need_adjust_scale>
    static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
                                                const LeftDataType& type_left,
                                                const RightDataType& type_right,
                                                const ResultDataType& type_result,
                                                const ResultType& max_result_number,
                                                const ResultType& scale_diff_multiplier) {
        static_assert(OpTraits::is_plus_minus || OpTraits::is_multiply);
        if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
            // Now, Doris only support decimal +-*/ decimal.
            if constexpr (check_overflow) {
                auto res = Op::apply(DecimalV2Value(a), DecimalV2Value(b)).value();
                if (res > max_result_number.value || res < -max_result_number.value) {
                    THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                            DecimalV2Value(a).to_string(), Name::name,
                            DecimalV2Value(b).to_string(), DecimalV2Value(res).to_string(),
                            ResultDataType {}.get_name());
                }
                return res;
            } else {
                return Op::apply(DecimalV2Value(a), DecimalV2Value(b)).value();
            }
        } else {
            NativeResultType res;
            if constexpr (OpTraits::can_overflow && check_overflow) {
                // TODO handle overflow gracefully
                // A true return from this overload is treated as a native overflow.
                if (UNLIKELY(Op::template apply<NativeResultType>(a, b, res))) {
                    if constexpr (OpTraits::is_plus_minus) {
                        auto result_str =
                                DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
                                                             type_result.get_scale()}
                                        .to_string(Decimal256(res));
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
                                result_str, type_result.get_name());
                    }
                    // multiply
                    if constexpr (std::is_same_v<NativeResultType, __int128>) {
                        // The 128-bit product overflowed. Redo it in 256 bits, because the
                        // rescaled value may still fit within max_result_number.
                        wide::Int256 res256 = Op::template apply<wide::Int256>(a, b);
                        if constexpr (OpTraits::is_multiply && need_adjust_scale) {
                            if (res256 > 0) {
                                res256 = (res256 + scale_diff_multiplier.value / 2) /
                                         scale_diff_multiplier.value;
                            } else {
                                res256 = (res256 - scale_diff_multiplier.value / 2) /
                                         scale_diff_multiplier.value;
                            }
                        }
                        // check if final result is overflow
                        if (res256 > wide::Int256(max_result_number.value) ||
                            res256 < wide::Int256(-max_result_number.value)) {
                            auto result_str =
                                    DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
                                                                 type_result.get_scale()}
                                            .to_string(Decimal256(res256));
                            THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                    type_left.to_string(A(a)), Name::name,
                                    type_right.to_string(B(b)), result_str, type_result.get_name());
                        } else {
                            res = res256;
                        }
                    } else {
                        // No wider native type is tried for other widths: report the overflow.
                        auto result_str =
                                DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
                                                             type_result.get_scale()}
                                        .to_string(Decimal256(res));
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
                                result_str, type_result.get_name());
                    }
                } else {
                    // round to final result precision
                    if constexpr (OpTraits::is_multiply && need_adjust_scale) {
                        if (res >= 0) {
                            res = (res + scale_diff_multiplier.value / 2) /
                                  scale_diff_multiplier.value;
                        } else {
                            res = (res - scale_diff_multiplier.value / 2) /
                                  scale_diff_multiplier.value;
                        }
                    }
                    if (res > max_result_number.value || res < -max_result_number.value) {
                        auto result_str =
                                DataTypeDecimal<Decimal256> {BeConsts::MAX_DECIMAL256_PRECISION,
                                                             type_result.get_scale()}
                                        .to_string(Decimal256(res));
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                type_left.to_string(A(a)), Name::name, type_right.to_string(B(b)),
                                result_str, type_result.get_name());
                    }
                }
                return res;
            } else {
                // Unchecked path: compute, then round half away from zero if rescaling.
                res = Op::template apply<NativeResultType>(a, b);
                if constexpr (OpTraits::is_multiply && need_adjust_scale) {
                    if (res >= 0) {
                        res = (res + scale_diff_multiplier.value / 2) / scale_diff_multiplier.value;
                    } else {
                        res = (res - scale_diff_multiplier.value / 2) / scale_diff_multiplier.value;
                    }
                }
                return res;
            }
        }
    }

    /// Per-row `a / b` or `a % b`. `is_null` is passed to the operation, which may set it.
    ///
    /// With DecimalV2 operands, the operation runs on DecimalV2Value wrappers. For division
    /// with `check_overflow`, a result magnitude above `max_result_number` throws. The
    /// operation's answer is copied bytewise into NativeResultType (at most
    /// sizeof(NativeResultType) bytes).
    static ALWAYS_INLINE NativeResultType apply(NativeResultType a, NativeResultType b,
                                                UInt8& is_null,
                                                const ResultType& max_result_number) {
        static_assert(OpTraits::is_division || OpTraits::is_mod);
        if constexpr (IsDecimalV2<B> || IsDecimalV2<A>) {
            DecimalV2Value l(a);
            DecimalV2Value r(b);
            auto ans = Op::apply(l, r, is_null);
            using ANS_TYPE = std::decay_t<decltype(ans)>;
            if constexpr (check_overflow && OpTraits::is_division) {
                // The answer type depends on the operation, so compare it in its own form.
                if constexpr (std::is_same_v<ANS_TYPE, DecimalV2Value>) {
                    if (ans.value() > max_result_number.value ||
                        ans.value() < -max_result_number.value) {
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                DecimalV2Value(a).to_string(), Name::name,
                                DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
                                ResultDataType {}.get_name());
                    }
                } else if constexpr (IsDecimalNumber<ANS_TYPE>) {
                    if (ans.value > max_result_number.value ||
                        ans.value < -max_result_number.value) {
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                DecimalV2Value(a).to_string(), Name::name,
                                DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
                                ResultDataType {}.get_name());
                    }
                } else {
                    if (ans > max_result_number.value || ans < -max_result_number.value) {
                        THROW_DECIMAL_BINARY_OP_OVERFLOW_EXCEPTION(
                                DecimalV2Value(a).to_string(), Name::name,
                                DecimalV2Value(b).to_string(), DecimalV2Value(ans).to_string(),
                                ResultDataType {}.get_name());
                    }
                }
            }
            NativeResultType result {};
            memcpy(&result, &ans, std::min(sizeof(result), sizeof(ans)));
            return result;
        } else {
            return Op::template apply<NativeResultType>(a, b, is_null);
        }
    }
};