cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h (603 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "SparkFunctionDecimalBinaryOperator.h" #include <Columns/ColumnDecimal.h> #include <Columns/ColumnNullable.h> #include <Columns/ColumnsNumber.h> #include <Core/DecimalFunctions.h> #include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypesDecimal.h> #include <Functions/FunctionFactory.h> #include <Functions/FunctionHelpers.h> #include <Functions/IFunction.h> #include <Functions/castTypeToEither.h> #include <Common/CurrentThread.h> #if USE_EMBEDDED_COMPILER #include <DataTypes/Native.h> #include <llvm/IR/IRBuilder.h> #endif namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_COLUMN; extern const int TYPE_MISMATCH; extern const int LOGICAL_ERROR; } } namespace local_engine { template <typename Op1, typename Op2> struct IsSameOperation { static constexpr bool value = std::is_same_v<Op1, Op2>; }; template <typename Op> struct SparkIsOperation { static constexpr bool plus = IsSameOperation<Op, DecimalPlusImpl>::value; static constexpr bool minus = IsSameOperation<Op, DecimalMinusImpl>::value; static constexpr bool plus_minus = IsSameOperation<Op, DecimalPlusImpl>::value || IsSameOperation<Op, DecimalMinusImpl>::value; static constexpr bool multiply = IsSameOperation<Op, DecimalMultiplyImpl>::value; static constexpr bool division = IsSameOperation<Op, DecimalDivideImpl>::value; static constexpr bool modulo = IsSameOperation<Op, DecimalModuloImpl>::value; }; using namespace DB; namespace { enum class OpCase : uint8_t { Vector, LeftConstant, RightConstant }; enum class OpMode : uint8_t { Default, Effect }; template <typename Operation, OpMode Mode> struct SparkDecimalBinaryOperation { private: static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus_minus; static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply; static constexpr bool is_division = SparkIsOperation<Operation>::division; static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo; public: static size_t getMaxScaled(size_t left_scale, size_t right_scale, size_t result_scale) { if constexpr (is_multiply) return left_scale + right_scale; else return std::max(result_scale, std::max(left_scale, right_scale)); } template <typename LeftDataType, typename RightDataType, typename ResultDataType> static bool shouldPromoteTo256(const LeftDataType & left_type, const RightDataType & right_type, const ResultDataType & result_type) { auto p1 = left_type.getPrecision(); auto s1 = left_type.getScale(); auto p2 = right_type.getPrecision(); auto s2 = right_type.getScale(); size_t precision; if constexpr (is_plus_minus) precision = std::max<size_t>(s1, s2) + std::max<size_t>(p1 - s1, p2 - s2) + 1; else if constexpr (is_multiply) precision = p1 + p2 + 1; else if constexpr (is_division) precision = p1 - s1 + s2 + std::max<size_t>(6, s1 + p2 + 1); else if constexpr (is_modulo) precision = std::min<size_t>(p1 - s1, p2 - s2) + std::max<size_t>(s1, s2); else throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown decimal binary operation"); if (precision > DataTypeDecimal128::maxPrecision()) return true; return false; } template <typename LeftDataType, typename RightDataType, typename ResultDataType> static ColumnPtr executeDecimal( const ColumnsWithTypeAndName & arguments, const LeftDataType & left_type, const RightDataType & right_type, const ResultDataType & result_type) { using LeftFieldType = typename LeftDataType::FieldType; using RightFieldType = typename RightDataType::FieldType; using ResultFieldType = typename ResultDataType::FieldType; using ColVecLeft = ColumnDecimal<LeftFieldType>; using ColVecRight = ColumnDecimal<RightFieldType>; ColumnPtr col_left = arguments[0].column; ColumnPtr col_right = arguments[1].column; const ColumnConst * col_left_const = checkAndGetColumnConst<ColVecLeft>(col_left.get()); const ColumnConst * col_right_const = checkAndGetColumnConst<ColVecRight>(col_right.get()); const ColVecLeft * col_left_vec = checkAndGetColumn<ColVecLeft>(col_left.get()); const ColVecRight * col_right_vec = checkAndGetColumn<ColVecRight>(col_right.get()); size_t rows = col_left->size(); size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale()); bool calculate_with_i256 = false; if constexpr (Mode != OpMode::Effect) { if (shouldPromoteTo256(left_type, right_type, result_type)) calculate_with_i256 = true; if (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision()) calculate_with_i256 = true; } auto p1 = left_type.getPrecision(); auto p2 = right_type.getPrecision(); if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + max_scale - left_type.getScale() || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + max_scale - right_type.getScale()) calculate_with_i256 = true; if (calculate_with_i256) { /// Use Int256 for calculation return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int256>( left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type); } else if constexpr (is_division) { /// Use Int128 for calculation return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int128>( left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type); } else { /// Use ResultNativeType for calculation return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, NativeType<ResultFieldType>>( left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type); } } private: template <typename LeftDataType, typename RightDataType, typename ResultDataType, typename ScaledNativeType> static ColumnPtr executeDecimalImpl( const LeftDataType & left_type, const RightDataType & right_type, const ColumnConst * col_left_const, const ColumnConst * col_right_const, const ColumnDecimal<typename LeftDataType::FieldType> * col_left_vec, const ColumnDecimal<typename RightDataType::FieldType> * col_right_vec, size_t rows, const ResultDataType & result_type) { using LeftFieldType = typename LeftDataType::FieldType; using RightFieldType = typename RightDataType::FieldType; using ResultFieldType = typename ResultDataType::FieldType; using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>; size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale()); ScaledNativeType scale_left = [&] { if constexpr (is_multiply) return ScaledNativeType{1}; auto diff = max_scale - left_type.getScale(); if constexpr (is_division) return DecimalUtils::scaleMultiplier<ScaledNativeType>(diff + max_scale); else return DecimalUtils::scaleMultiplier<ScaledNativeType>(diff); }(); ScaledNativeType scale_right = [&] { if constexpr (is_multiply) return ScaledNativeType{1}; else return DecimalUtils::scaleMultiplier<ScaledNativeType>(max_scale - right_type.getScale()); }(); ScaledNativeType unscale_result = [&] { auto result_scale = result_type.getScale(); auto diff = max_scale - result_scale; chassert(diff >= 0); return DecimalUtils::scaleMultiplier<ScaledNativeType>(diff); }(); ScaledNativeType max_value = intExp10OfSize<ScaledNativeType>(result_type.getPrecision()); auto res_vec = ColVecResult::create(rows, result_type.getScale()); auto & res_vec_data = res_vec->getData(); auto res_null_map = ColumnUInt8::create(rows, 0); auto & res_nullmap_data = res_null_map->getData(); if (col_left_vec && col_right_vec) { process<OpCase::Vector>( col_left_vec->getData().data(), col_right_vec->getData().data(), res_vec_data, res_nullmap_data, rows, scale_left, scale_right, unscale_result, max_value); } else if (col_left_const && col_right_vec) { LeftFieldType left_value = col_left_const->getValue<LeftFieldType>(); process<OpCase::LeftConstant>( &left_value, col_right_vec->getData().data(), res_vec_data, res_nullmap_data, rows, scale_left, scale_right, unscale_result, max_value); } else if (col_left_vec && col_right_const) { RightFieldType right_value = col_right_const->getValue<RightFieldType>(); process<OpCase::RightConstant>( col_left_vec->getData().data(), &right_value, res_vec_data, res_nullmap_data, rows, scale_left, scale_right, unscale_result, max_value); } else throw Exception( ErrorCodes::LOGICAL_ERROR, "Unexpected argument types {} {} {}", left_type.getName(), right_type.getName(), result_type.getName()); return ColumnNullable::create(std::move(res_vec), std::move(res_null_map)); } template < OpCase op_case, typename LeftFieldType, typename RightFieldType, typename ResultFieldType, typename ScaledNativeType> static void NO_INLINE process( const LeftFieldType * __restrict left_data, // maybe scalar or vector const RightFieldType * __restrict right_data, // maybe scalar or vector PaddedPODArray<ResultFieldType> & __restrict res_vec_data, // should be vector NullMap & res_nullmap_data, size_t rows, const ScaledNativeType & scale_left, const ScaledNativeType & scale_right, const ScaledNativeType & unscale_result, const ScaledNativeType & max_value) { using ResultNativeType = NativeType<ResultFieldType>; if constexpr (op_case == OpCase::Vector) { for (size_t i = 0; i < rows; ++i) res_nullmap_data[i] = !calculate( static_cast<ScaledNativeType>(unwrap<op_case == OpCase::LeftConstant>(left_data, i)), static_cast<ScaledNativeType>(unwrap<op_case == OpCase::RightConstant>(right_data, i)), scale_left, scale_right, unscale_result, max_value, res_vec_data[i].value); } else if constexpr (op_case == OpCase::LeftConstant) { ScaledNativeType scaled_left = applyScaled(static_cast<ScaledNativeType>(unwrap<op_case == OpCase::LeftConstant>(left_data, 0)), scale_left); for (size_t i = 0; i < rows; ++i) res_nullmap_data[i] = !calculate( scaled_left, static_cast<ScaledNativeType>(unwrap<op_case == OpCase::RightConstant>(right_data, i)), static_cast<ScaledNativeType>(1), scale_right, unscale_result, max_value, res_vec_data[i].value); } else if constexpr (op_case == OpCase::RightConstant) { ScaledNativeType scaled_right = applyScaled(static_cast<ScaledNativeType>(unwrap<op_case == OpCase::RightConstant>(right_data, 0)), scale_right); for (size_t i = 0; i < rows; ++i) res_nullmap_data[i] = !calculate( static_cast<ScaledNativeType>(unwrap<op_case == OpCase::LeftConstant>(left_data, i)), scaled_right, scale_left, static_cast<ScaledNativeType>(1), unscale_result, max_value, res_vec_data[i].value); } } template < typename ScaledNativeType, typename ResultNativeType> static ALWAYS_INLINE bool calculate( const ScaledNativeType & left, const ScaledNativeType & right, const ScaledNativeType & scale_left, const ScaledNativeType & scale_right, const ScaledNativeType & unscale_result, const ScaledNativeType & max_value, ResultNativeType & res) { auto scaled_left = scale_left > 1 ? applyScaled(left, scale_left) : left; auto scaled_right = scale_right > 1 ? applyScaled(right, scale_right) : right; ScaledNativeType c_res = 0; auto success = Operation::template apply<>(scaled_left, scaled_right, c_res); if (!success) return false; if (unscale_result > 1) c_res = applyUnscaled(c_res, unscale_result); res = static_cast<ResultNativeType>(c_res); if constexpr (std::is_same_v<ScaledNativeType, Int256> || is_division) return c_res > -max_value && c_res < max_value; else return true; } /// Unwrap underlying native type from decimal type template <bool is_scalar, typename E> static auto unwrap(const E * elem, size_t i) { if constexpr (is_scalar) return elem->value; else return elem[i].value; } template <typename T> static ALWAYS_INLINE T applyScaled(T n, T scale) { chassert(scale != 0); T res; DecimalMultiplyImpl::apply(n, scale, res); return res; } template <typename T> static ALWAYS_INLINE T applyUnscaled(T n, T scale) { chassert(scale != 0); T res; DecimalDivideImpl::apply(n, scale, res); return res; } }; template <class Operation, typename Name, OpMode mode = OpMode::Default> class SparkFunctionDecimalBinaryArithmetic final : public IFunction { static constexpr bool is_plus = SparkIsOperation<Operation>::plus; static constexpr bool is_minus = SparkIsOperation<Operation>::minus; static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus || SparkIsOperation<Operation>::minus; static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply; static constexpr bool is_division = SparkIsOperation<Operation>::division; static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo; public: static constexpr auto name = Name::name; static FunctionPtr create(ContextPtr context_) { return std::make_shared<SparkFunctionDecimalBinaryArithmetic>(context_); } explicit SparkFunctionDecimalBinaryArithmetic(ContextPtr context_) : context(context_) { } String getName() const override { return name; } size_t getNumberOfArguments() const override { return 3; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {2}; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { if (arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 3 arguments", getName()); if (!isDecimal(arguments[0]) || !isDecimal(arguments[1]) || !isDecimal(arguments[2])) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} {} {} of argument of function {}", arguments[0]->getName(), arguments[1]->getName(), arguments[2]->getName(), getName()); return makeNullable(arguments[2]); } // executeImpl2 ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override { const auto & left_argument = arguments[0]; const auto & right_argument = arguments[1]; const auto * left_generic = left_argument.type.get(); const auto * right_generic = right_argument.type.get(); ColumnPtr res; bool valid = castTripleTypes( left_generic, right_generic, removeNullable(arguments[2].type).get(), [&](const auto & left, const auto & right, const auto & result) { return (res = SparkDecimalBinaryOperation<Operation, mode>::template executeDecimal<>(arguments, left, right, result)) != nullptr; }); if (!valid) { // This is a logical error, because the types should have been checked // by getReturnTypeImpl(). throw Exception( ErrorCodes::LOGICAL_ERROR, "Arguments of '{}' have incorrect data types: '{}' of type '{}'," " '{}' of type '{}'", getName(), left_argument.name, left_argument.type->getName(), right_argument.name, right_argument.type->getName()); } return res; } #if USE_EMBEDDED_COMPILER virtual ColumnNumbers getArgumentsThatDontParticipateInCompilation(const DataTypes & /*types*/) const { return {2}; } bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override { const auto & denull_left_type = arguments[0]; const auto & denull_right_type = arguments[1]; const auto & denull_result_type = removeNullable(result_type); if (!canBeNativeType(denull_left_type) || !canBeNativeType(denull_right_type) || !canBeNativeType(denull_result_type)) return false; return castTripleTypes( denull_left_type.get(), denull_right_type.get(), denull_result_type.get(), [&](const auto & left_type, const auto & right_type, const auto & result_type) { using LeftDataType = std::decay_t<decltype(left_type)>; using RightDataType = std::decay_t<decltype(right_type)>; using ResultDataType = std::decay_t<decltype(result_type)>; using LeftFieldType = typename LeftDataType::FieldType; using RightFieldType = typename RightDataType::FieldType; using ResultFieldType = typename ResultDataType::FieldType; size_t max_scale = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled( left_type.getScale(), right_type.getScale(), result_type.getScale()); auto p1 = left_type.getPrecision(); auto p2 = right_type.getPrecision(); if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + max_scale - left_type.getScale() || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + max_scale - right_type.getScale()) return false; if (SparkDecimalBinaryOperation<Operation, mode>::shouldPromoteTo256(left_type, right_type, result_type) || (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision())) return false; return true; }); } llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override { const auto & denull_left_type = arguments[0].type; const auto & denull_right_type = arguments[1].type; const auto & denull_result_type = removeNullable(result_type); llvm::Value * nullable_result = nullptr; castTripleTypes( denull_left_type.get(), denull_right_type.get(), denull_result_type.get(), [&](const auto & left_type, const auto & right_type, const auto & result_type) { using LeftDataType = std::decay_t<decltype(left_type)>; using RightDataType = std::decay_t<decltype(right_type)>; using ResultDataType = std::decay_t<decltype(result_type)>; using LeftFieldType = typename LeftDataType::FieldType; using RightFieldType = typename RightDataType::FieldType; using ResultFieldType = typename ResultDataType::FieldType; using LeftNativeType = NativeType<LeftFieldType>; using RightNativeType = NativeType<RightFieldType>; using ResultNativeType = NativeType<ResultFieldType>; size_t max_scale = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled( left_type.getScale(), right_type.getScale(), result_type.getScale()); auto p1 = left_type.getPrecision(); auto p2 = right_type.getPrecision(); bool calculate_with_256 = false; if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + max_scale - left_type.getScale() || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + max_scale - right_type.getScale()) calculate_with_256 = true; if (SparkDecimalBinaryOperation<Operation, mode>::shouldPromoteTo256(left_type, right_type, result_type) || (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision()) || calculate_with_256) nullable_result = compileHelper<Int256>(builder, arguments, left_type, right_type, result_type); // nullable_result = compileHelper<Int128>(builder, arguments, left_type, right_type, result_type); else if (is_division) nullable_result = compileHelper<Int128>(builder, arguments, left_type, right_type, result_type); else nullable_result = compileHelper<ResultNativeType>(builder, arguments, left_type, right_type, result_type); return true; }); return nullable_result; } template <typename CalculateType, typename LeftDataType, typename RightDataType, typename ResultDataType> static llvm::Value * compileHelper( llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const LeftDataType & left_type, const RightDataType & right_type, const ResultDataType & result_type) { auto & b = static_cast<llvm::IRBuilder<> &>(builder); DataTypePtr calculate_type = std::make_shared<DataTypeNumber<CalculateType>>(); auto * left = nativeCast(b, arguments[0], calculate_type); auto * right = nativeCast(b, arguments[1], calculate_type); size_t max_scale = SparkDecimalBinaryOperation<Operation, mode>::getMaxScaled( left_type.getScale(), right_type.getScale(), result_type.getScale()); CalculateType scale_left = [&] { if constexpr (is_multiply) return CalculateType{1}; auto diff = max_scale - left_type.getScale(); if constexpr (is_division) return DecimalUtils::scaleMultiplier<CalculateType>(diff + max_scale); else return DecimalUtils::scaleMultiplier<CalculateType>(diff); }(); CalculateType scale_right = [&] { if constexpr (is_multiply) return CalculateType{1}; else return DecimalUtils::scaleMultiplier<CalculateType>(max_scale - right_type.getScale()); }(); auto * scaled_left = b.CreateMul(left, getNativeConstant(b, scale_left)); auto * scaled_right = b.CreateMul(right, getNativeConstant(b, scale_right)); llvm::Value * scaled_result = nullptr; llvm::Value * is_null = llvm::ConstantInt::getFalse(b.getContext()); if constexpr (is_plus) scaled_result = b.CreateAdd(scaled_left, scaled_right); else if constexpr (is_minus) scaled_result = b.CreateSub(scaled_left, scaled_right); else if constexpr (is_multiply) scaled_result = b.CreateMul(scaled_left, scaled_right); else if constexpr (is_division) { auto * zero = getNativeConstant(b, static_cast<CalculateType>(0)); auto * is_zero = b.CreateICmpEQ(scaled_right, zero); scaled_result = b.CreateSDiv(scaled_left, scaled_right); is_null = is_zero; } else if constexpr (is_modulo) { auto * zero = getNativeConstant(b, static_cast<CalculateType>(0)); auto * is_zero = b.CreateICmpEQ(scaled_right, zero); scaled_result = b.CreateSRem(scaled_left, scaled_right); is_null = is_zero; } auto result_scale = result_type.getScale(); auto scale_diff = max_scale - result_scale; auto * unscaled_result = scaled_result; if (scale_diff) { auto scaled_diff = DecimalUtils::scaleMultiplier<CalculateType>(scale_diff); unscaled_result = b.CreateSDiv(scaled_result, getNativeConstant(b, scaled_diff)); } /// check overflow if constexpr (std::is_same_v<CalculateType, Int256> || is_division) { auto max_value = intExp10OfSize<CalculateType>(result_type.getPrecision()); auto * max_value_const = getNativeConstant(b, max_value); auto * is_overflow = b.CreateOr( b.CreateICmpSGE(unscaled_result, max_value_const), b.CreateICmpSLE(unscaled_result, b.CreateNeg(max_value_const))); auto * overflow_result = getNativeConstant(b, static_cast<CalculateType>(0)); is_null = b.CreateOr(is_null, is_overflow); } auto * result = nativeCast(b, calculate_type, unscaled_result, result_type.getPtr()); auto * nullable_type = toNativeType(b, makeNullable(result_type.getPtr())); auto * nullable_result = llvm::Constant::getNullValue(nullable_type); auto * nullablel_result_with_value = b.CreateInsertValue(nullable_result, result, {0}); return b.CreateInsertValue(nullablel_result_with_value, is_null, {1}); } template <is_integer T> static llvm::Constant * getNativeConstant(llvm::IRBuilderBase & builder, T element) { auto * type = llvm::Type::getIntNTy(builder.getContext(), sizeof(T) * 8); if constexpr (std::is_integral_v<T>) { return llvm::ConstantInt::get(type, static_cast<uint64_t>(element), true); } else { llvm::APInt value(type->getIntegerBitWidth(), element.items); return llvm::ConstantInt::get(type, value); } } #endif // USE_EMBEDDED_COMPILER private: template <typename F> static bool castTripleTypes(const IDataType * left, const IDataType * right, const IDataType * result, F && f) { return castType( left, [&](const auto & left_) { return castType( right, [&](const auto & right_) { return castType(result, [&](const auto & result_) { return f(left_, right_, result_); }); }); }); } static bool castType(const IDataType * type, auto && f) { using Types = TypeList<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128, DataTypeDecimal256>; return castTypeToEither(Types{}, type, std::forward<decltype(f)>(f)); } ContextPtr context; }; } }