cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp (302 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <Core/Settings.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Common/BlockTypeUtils.h>
#include <Common/GlutenSettings.h>

namespace DB::ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

namespace local_engine
{
using namespace DB;

/// A (precision, scale) pair together with static helpers that derive the decimal type of a
/// binary arithmetic expression from the operand types `(p1, s1)` and `(p2, s2)`.
/// Every helper clamps its result with bounded_to_click_house().
class DecimalType
{
    static constexpr Int32 spark_max_precision = 38;
    static constexpr Int32 spark_max_scale = 38;
    static constexpr Int32 minimum_adjusted_scale = 6;

    static constexpr Int32 clickhouse_max_precision = DB::DataTypeDecimal256::maxPrecision();
    /// NOTE(review): the scale bound is taken from Decimal128's max *precision* - confirm this is intended.
    static constexpr Int32 clickhouse_max_scale = DB::DataTypeDecimal128::maxPrecision();

public:
    Int32 precision;
    Int32 scale;

private:
    /// Clamps precision to `clickhouse_max_precision` and scale to `clickhouse_max_scale`.
    static DecimalType bounded_to_click_house(const Int32 precision, const Int32 scale)
    {
        /// Brace-initialisation: DecimalType is an aggregate, and `DecimalType(a, b)` is only accepted since C++20.
        return DecimalType{std::min(precision, clickhouse_max_precision), std::min(scale, clickhouse_max_scale)};
    }

public:
    /// Add / subtract: scale is the left operand's scale; precision is
    /// `scale + max(p1 - s1, p2 - s2) + 1` (widest integral part plus one extra digit).
    static DecimalType evalAddSubstractDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 scale = s1;
        const Int32 precision = scale + std::max(p1 - s1, p2 - s2) + 1;
        return bounded_to_click_house(precision, scale);
    }

    /// Divide: scale is `max(minimum_adjusted_scale, s1 + p2 + 1)`; precision is `p1 - s1 + s2 + scale`.
    static DecimalType evalDividetDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 scale = std::max(minimum_adjusted_scale, s1 + p2 + 1);
        const Int32 precision = p1 - s1 + s2 + scale;
        return bounded_to_click_house(precision, scale);
    }

    /// Modulo: scale is `max(s1, s2)`; precision is `min(p1 - s1, p2 - s2) + scale`.
    static DecimalType evalModuloDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 scale = std::max(s1, s2);
        const Int32 precision = std::min(p1 - s1, p2 - s2) + scale;
        return bounded_to_click_house(precision, scale);
    }

    /// Multiply: scale is the left operand's scale; precision is `p1 + p2 + 1`. `s2` does not take part.
    static DecimalType evalMultiplyDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 scale = s1;
        const Int32 precision = p1 + p2 + 1;
        return bounded_to_click_house(precision, scale);
    }
};

/// Common base of the binary arithmetic parsers (add, subtract, multiply, divide, modulus).
/// parse() validates the argument count and delegates node construction to the virtual
/// createFunctionNode(), which each concrete parser overrides.
class FunctionParserBinaryArithmetic : public FunctionParser
{
protected:
    /// Returns `{CAST(args[0] AS Decimal(eval_type.precision, eval_type.scale)), args[1]}`.
    /// The cast target is made nullable according to the nullability of the Substrait decimal
    /// output type, and the cast node is also added to (or replaced in) the DAG outputs.
    /// Reads args[0] and args[1] without checking the size of `args`.
    /// No call site is visible in this file.
    ActionsDAG::NodeRawConstPtrs convertBinaryArithmeticFunDecimalArgs(
        ActionsDAG & actions_dag,
        const ActionsDAG::NodeRawConstPtrs & args,
        const DecimalType & eval_type,
        const substrait::Expression_ScalarFunction & arithmeticFun) const
    {
        const Int32 precision = eval_type.precision;
        const Int32 scale = eval_type.scale;

        ActionsDAG::NodeRawConstPtrs new_args;
        new_args.reserve(args.size());

        ActionsDAG::NodeRawConstPtrs cast_args;
        cast_args.reserve(2);
        cast_args.emplace_back(args[0]);

        DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale);
        ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type);

        /// CAST receives its target type as a constant string column holding the type name.
        const String type_name = ch_type->getName();
        const DataTypePtr str_type = std::make_shared<DataTypeString>();
        const ActionsDAG::Node * type_node
            = &actions_dag.addColumn(ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name)));
        cast_args.emplace_back(type_node);

        const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args);
        actions_dag.addOrReplaceInOutputs(*cast_node);

        new_args.emplace_back(cast_node);
        new_args.emplace_back(args[1]);
        return new_args;
    }

    /// Evaluates the subclass-specific decimal type for two decimal operand types.
    /// Both types must already be non-nullable decimals; this is only checked by `assert`.
    DecimalType getDecimalType(const DataTypePtr & left, const DataTypePtr & right) const
    {
        assert(isDecimal(left) && isDecimal(right));
        const Int32 p1 = getDecimalPrecision(*left);
        const Int32 s1 = getDecimalScale(*left);
        const Int32 p2 = getDecimalPrecision(*right);
        const Int32 s2 = getDecimalScale(*right);
        return internalEvalType(p1, s1, p2, s2);
    }

    /// Operation-specific precision/scale rule; each subclass forwards to one of the DecimalType helpers.
    virtual DecimalType internalEvalType(Int32 p1, Int32 s1, Int32 p2, Int32 s2) const = 0;

    /// Wraps `func_node` into `checkDecimalOverflowSparkOrNull(func_node, precision, scale)`.
    const ActionsDAG::Node *
    checkDecimalOverflow(ActionsDAG & actions_dag, const ActionsDAG::Node * func_node, Int32 precision, Int32 scale) const
    {
        //TODO: checkDecimalOverflowSpark throw exception per configuration
        const DB::ActionsDAG::NodeRawConstPtrs overflow_args
            = {func_node,
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision),
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale)};
        return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", overflow_args);
    }

    /// Shared by the concrete parsers for their decimal branch.
    /// Adds a constant column of `result_type` (default value, one row) to the DAG and builds
    /// `f(left_arg, right_arg, <that constant>)`, where `f` is `effect_func_name` when the setting
    /// `arithmetic.decimal.mode` exists and equals `EFFECT`, and `default_func_name` otherwise.
    const DB::ActionsDAG::Node * toSparkDecimalFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & default_func_name,
        const String & effect_func_name,
        const DB::ActionsDAG::Node * left_arg,
        const DB::ActionsDAG::Node * right_arg,
        const DataTypePtr & result_type) const
    {
        /// The result type is handed to the function as a third, constant argument.
        const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & settings = parser_context->queryContext()->getSettingsRef();
        const bool effect_mode = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT");
        const String & function_name = effect_mode ? effect_func_name : default_func_name;
        return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node});
    }

    /// Default node construction: plain `func_name(args...)`; `result_type` is ignored here.
    virtual const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args, DataTypePtr result_type) const
    {
        return toFunctionNode(actions_dag, func_name, args);
    }

public:
    explicit FunctionParserBinaryArithmetic(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { }

    /// Parses the two arguments of `substrait_func`, builds the arithmetic node via
    /// createFunctionNode() using the non-nullable parsed output type, and finally passes the
    /// node through convertNodeTypeIfNeeded().
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly two arguments were parsed.
    const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const override
    {
        const auto ch_func_name = getCHFunctionName(substrait_func);
        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);

        if (parsed_args.size() != 2)
            throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

        const auto result_type = removeNullable(TypeParser::parseType(substrait_func.output_type()));
        const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args, result_type);
        return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
    }
};

/// Parser registered as "add"; maps to "plus", or to sparkDecimalPlus[Effect] when both operands are decimals.
class FunctionParserPlus final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserPlus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "add";
    String getName() const override { return name; }
    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "plus"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
    }

    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * left_arg = new_args[0];
        const auto * right_arg = new_args[1];

        /// The Spark decimal implementation is used only when BOTH operands are decimals.
        if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type)))
            return toSparkDecimalFunctionNode(actions_dag, "sparkDecimalPlus", "sparkDecimalPlusEffect", left_arg, right_arg, result_type);

        return toFunctionNode(actions_dag, "plus", {left_arg, right_arg});
    }
};

/// Parser registered as "subtract"; maps to "minus", or to sparkDecimalMinus[Effect] when both operands are decimals.
class FunctionParserMinus final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserMinus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "subtract";
    String getName() const override { return name; }
    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "minus"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
    }

    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * left_arg = new_args[0];
        const auto * right_arg = new_args[1];

        /// The Spark decimal implementation is used only when BOTH operands are decimals.
        if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type)))
            return toSparkDecimalFunctionNode(actions_dag, "sparkDecimalMinus", "sparkDecimalMinusEffect", left_arg, right_arg, result_type);

        return toFunctionNode(actions_dag, "minus", {left_arg, right_arg});
    }
};

/// Parser registered as "multiply"; maps to "multiply", or to sparkDecimalMultiply[Effect] when both operands are decimals.
class FunctionParserMultiply final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserMultiply(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "multiply";
    String getName() const override { return name; }
    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "multiply"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2);
    }

    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * left_arg = new_args[0];
        const auto * right_arg = new_args[1];

        /// The Spark decimal implementation is used only when BOTH operands are decimals.
        if (isDecimal(removeNullable(left_arg->result_type)) && isDecimal(removeNullable(right_arg->result_type)))
            return toSparkDecimalFunctionNode(
                actions_dag, "sparkDecimalMultiply", "sparkDecimalMultiplyEffect", left_arg, right_arg, result_type);

        return toFunctionNode(actions_dag, "multiply", {left_arg, right_arg});
    }
};

/// Parser registered as "modulus"; maps to "spark_modulo", or to sparkDecimalModulo[Effect] when either operand is a decimal.
class FunctionParserModulo final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserModulo(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "modulus";
    String getName() const override { return name; }
    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "modulo"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalModuloDecimalType(p1, s1, p2, s2);
    }

    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * left_arg = new_args[0];
        const auto * right_arg = new_args[1];

        /// Unlike plus/minus/multiply, ONE decimal operand is enough to take the Spark decimal path.
        if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type)))
            return toSparkDecimalFunctionNode(actions_dag, "sparkDecimalModulo", "sparkDecimalModuloEffect", left_arg, right_arg, result_type);

        return toFunctionNode(actions_dag, "spark_modulo", {left_arg, right_arg});
    }
};

/// Parser registered as "divide"; maps to "sparkDivide", or to sparkDecimalDivide[Effect] when either operand is a decimal.
class FunctionParserDivide final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserDivide(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "divide";
    String getName() const override { return name; }
    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "divide"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalDividetDecimalType(p1, s1, p2, s2);
    }

    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        /// parse() passes getCHFunctionName(), which for this parser is the same string as `name`.
        assert(func_name == name);

        const auto * left_arg = new_args[0];
        const auto * right_arg = new_args[1];

        /// Unlike plus/minus/multiply, ONE decimal operand is enough to take the Spark decimal path.
        if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type)))
            return toSparkDecimalFunctionNode(actions_dag, "sparkDecimalDivide", "sparkDecimalDivideEffect", left_arg, right_arg, result_type);

        return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg});
    }
};

static FunctionParserRegister<FunctionParserPlus> register_plus;
static FunctionParserRegister<FunctionParserMinus> register_minus;
static FunctionParserRegister<FunctionParserMultiply> register_multiply;
static FunctionParserRegister<FunctionParserDivide> register_divide;
static FunctionParserRegister<FunctionParserModulo> register_modulo;
}