cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp (302 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Core/Settings.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionHelpers.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Common/BlockTypeUtils.h>
#include <Common/GlutenSettings.h>
/// Error codes used by this file. They are only declared here (extern); the definitions live elsewhere.
namespace DB::ErrorCodes
{
/// NOTE(review): not referenced in the visible part of this file — confirm whether it is still needed.
extern const int BAD_ARGUMENTS;
/// Thrown by FunctionParserBinaryArithmetic::parse when the argument count is not exactly two.
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace local_engine
{
using namespace DB;
/// A (precision, scale) pair describing the decimal type produced by a binary arithmetic operation.
///
/// Each static eval* helper derives a raw precision and scale from the operand types
/// (p1, s1) and (p2, s2) and then caps both values with the ClickHouse limits declared
/// at the bottom of the class.
class DecimalType
{
public:
    Int32 precision;
    Int32 scale;

    /// Result type for addition and subtraction.
    /// The scale is taken from the left operand only; the precision is that scale plus the
    /// larger integral-digit count of the two operands plus one extra digit.
    /// NOTE(review): presumably both operands already share the same scale when this is called — confirm.
    static DecimalType evalAddSubstractDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 integral_digits = std::max(p1 - s1, p2 - s2);
        return clampToClickHouse(s1 + integral_digits + 1, s1);
    }

    /// Result type for division.
    /// The scale is max(minimum_adjusted_scale, s1 + p2 + 1); the precision is p1 - s1 + s2 + scale.
    static DecimalType evalDividetDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 result_scale = std::max(minimum_adjusted_scale, s1 + p2 + 1);
        return clampToClickHouse(p1 - s1 + s2 + result_scale, result_scale);
    }

    /// Result type for modulo.
    /// The scale is the larger operand scale; the precision is the smaller integral-digit count plus that scale.
    static DecimalType evalModuloDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        const Int32 result_scale = std::max(s1, s2);
        const Int32 integral_digits = std::min(p1 - s1, p2 - s2);
        return clampToClickHouse(integral_digits + result_scale, result_scale);
    }

    /// Result type for multiplication.
    /// The scale is taken from the left operand only (s2 is not used); the precision is p1 + p2 + 1.
    static DecimalType evalMultiplyDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2)
    {
        return clampToClickHouse(p1 + p2 + 1, s1);
    }

private:
    /// spark_max_precision and spark_max_scale are not referenced by any helper in this class.
    static constexpr Int32 spark_max_precision = 38;
    static constexpr Int32 spark_max_scale = 38;
    /// Lower bound for the scale computed by evalDividetDecimalType.
    static constexpr Int32 minimum_adjusted_scale = 6;

    /// Caps applied to every computed result: the precision cap comes from Decimal256,
    /// the scale cap from Decimal128.
    static constexpr Int32 clickhouse_max_precision = DB::DataTypeDecimal256::maxPrecision();
    static constexpr Int32 clickhouse_max_scale = DB::DataTypeDecimal128::maxPrecision();

    /// Caps the raw precision and scale at the ClickHouse maxima. Values below the caps pass through unchanged.
    static DecimalType clampToClickHouse(const Int32 raw_precision, const Int32 raw_scale)
    {
        return DecimalType{std::min(raw_precision, clickhouse_max_precision), std::min(raw_scale, clickhouse_max_scale)};
    }
};
/// Shared base of the binary arithmetic function parsers defined in this file.
///
/// parse() checks that exactly two arguments were supplied, lets the virtual
/// createFunctionNode() build the function node, and returns that node after passing it
/// through convertNodeTypeIfNeeded(). Subclasses provide the decimal result-type rule
/// (internalEvalType) and may override createFunctionNode().
class FunctionParserBinaryArithmetic : public FunctionParser
{
protected:
    /// Returns {CAST(args[0] AS Decimal(eval_type.precision, eval_type.scale)), args[1]}.
    ///
    /// The cast target type is wrapped by wrapNullableType() using the nullability of the
    /// Substrait decimal output type of `arithmeticFun`. The CAST node is additionally
    /// registered in the DAG outputs through addOrReplaceInOutputs().
    /// `args` must hold at least two nodes; this is not checked here.
    ActionsDAG::NodeRawConstPtrs convertBinaryArithmeticFunDecimalArgs(
        ActionsDAG & actions_dag,
        const ActionsDAG::NodeRawConstPtrs & args,
        const DecimalType & eval_type,
        const substrait::Expression_ScalarFunction & arithmeticFun) const
    {
        const Int32 precision = eval_type.precision;
        const Int32 scale = eval_type.scale;

        ActionsDAG::NodeRawConstPtrs new_args;
        new_args.reserve(args.size());

        /// CAST takes the value to convert and a constant string naming the target type.
        ActionsDAG::NodeRawConstPtrs cast_args;
        cast_args.reserve(2);
        cast_args.emplace_back(args[0]);

        DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale);
        ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type);
        const String type_name = ch_type->getName();

        const DataTypePtr str_type = std::make_shared<DataTypeString>();
        const ActionsDAG::Node * type_node
            = &actions_dag.addColumn(ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name)));
        cast_args.emplace_back(type_node);

        const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args);
        actions_dag.addOrReplaceInOutputs(*cast_node);

        new_args.emplace_back(cast_node);
        new_args.emplace_back(args[1]);
        return new_args;
    }

    /// Extracts precision and scale from two decimal types and forwards them to internalEvalType().
    /// Both types must already be decimals (Nullable removed); this is only asserted.
    DecimalType getDecimalType(const DataTypePtr & left, const DataTypePtr & right) const
    {
        assert(isDecimal(left) && isDecimal(right));
        const Int32 p1 = getDecimalPrecision(*left);
        const Int32 s1 = getDecimalScale(*left);
        const Int32 p2 = getDecimalPrecision(*right);
        const Int32 s2 = getDecimalScale(*right);
        return internalEvalType(p1, s1, p2, s2);
    }

    /// Operation-specific rule that maps the operand precisions/scales to the result DecimalType.
    virtual DecimalType internalEvalType(Int32 p1, Int32 s1, Int32 p2, Int32 s2) const = 0;

    /// Wraps `func_node` in "checkDecimalOverflowSparkOrNull", passing `precision` and `scale`
    /// as Int32 constant columns.
    const ActionsDAG::Node *
    checkDecimalOverflow(ActionsDAG & actions_dag, const ActionsDAG::Node * func_node, Int32 precision, Int32 scale) const
    {
        //TODO: checkDecimalOverflowSpark throw exception per configuration
        const DB::ActionsDAG::NodeRawConstPtrs overflow_args
            = {func_node,
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision),
               expression_parser->addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale)};
        return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", overflow_args);
    }

    /// Builds the function node for the operation.
    /// The default implementation calls `func_name` on `args` and ignores `result_type`.
    virtual const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & args,
        DataTypePtr result_type) const
    {
        return toFunctionNode(actions_dag, func_name, args);
    }

public:
    explicit FunctionParserBinaryArithmetic(ParserContextPtr parser_context_) : FunctionParser(parser_context_) { }

    /// Parses a binary arithmetic Substrait function into a DAG node.
    ///
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly two arguments are parsed.
    /// The result type handed to createFunctionNode() is the Substrait output type with Nullable removed.
    const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAG & actions_dag) const override
    {
        const auto ch_func_name = getCHFunctionName(substrait_func);
        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);

        if (parsed_args.size() != 2)
            throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

        const auto result_type = removeNullable(TypeParser::parseType(substrait_func.output_type()));
        const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args, result_type);
        return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
    }
};
/// Parser for the Substrait scalar function "add".
///
/// If both operands are decimals (after removing Nullable) the node is built with
/// "sparkDecimalPlus", or with "sparkDecimalPlusEffect" when the query settings contain
/// "arithmetic.decimal.mode" and settingsEqual() matches it against "EFFECT". Either variant
/// receives a third constant argument typed as the expected result type.
/// In every other case the "plus" function is used.
class FunctionParserPlus final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserPlus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "add";

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "plus"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
    }

    /// `func_name` is not used; the function name is chosen from the operand types and settings.
    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * lhs = new_args[0];
        const auto * rhs = new_args[1];

        /// NOTE(review): the modulus and divide parsers in this file take the decimal path when
        /// *either* operand is decimal, whereas here both must be — confirm the difference is intended.
        if (!isDecimal(removeNullable(lhs->result_type)) || !isDecimal(removeNullable(rhs->result_type)))
            return toFunctionNode(actions_dag, "plus", {lhs, rhs});

        /// Constant column whose only purpose is to carry the expected result type.
        const ActionsDAG::Node * result_type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & query_settings = parser_context->queryContext()->getSettingsRef();
        const char * decimal_function = "sparkDecimalPlus";
        if (query_settings.has("arithmetic.decimal.mode") && settingsEqual(query_settings, "arithmetic.decimal.mode", "EFFECT"))
            decimal_function = "sparkDecimalPlusEffect";

        return toFunctionNode(actions_dag, decimal_function, {lhs, rhs, result_type_node});
    }
};
/// Parser for the Substrait scalar function "subtract".
///
/// If both operands are decimals (after removing Nullable) the node is built with
/// "sparkDecimalMinus", or with "sparkDecimalMinusEffect" when the query settings contain
/// "arithmetic.decimal.mode" and settingsEqual() matches it against "EFFECT". Either variant
/// receives a third constant argument typed as the expected result type.
/// In every other case the "minus" function is used.
class FunctionParserMinus final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserMinus(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "subtract";

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "minus"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
    }

    /// `func_name` is not used; the function name is chosen from the operand types and settings.
    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * lhs = new_args[0];
        const auto * rhs = new_args[1];

        /// NOTE(review): the modulus and divide parsers in this file take the decimal path when
        /// *either* operand is decimal, whereas here both must be — confirm the difference is intended.
        if (!isDecimal(removeNullable(lhs->result_type)) || !isDecimal(removeNullable(rhs->result_type)))
            return toFunctionNode(actions_dag, "minus", {lhs, rhs});

        /// Constant column whose only purpose is to carry the expected result type.
        const ActionsDAG::Node * result_type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & query_settings = parser_context->queryContext()->getSettingsRef();
        const char * decimal_function = "sparkDecimalMinus";
        if (query_settings.has("arithmetic.decimal.mode") && settingsEqual(query_settings, "arithmetic.decimal.mode", "EFFECT"))
            decimal_function = "sparkDecimalMinusEffect";

        return toFunctionNode(actions_dag, decimal_function, {lhs, rhs, result_type_node});
    }
};
/// Parser for the Substrait scalar function "multiply".
///
/// If both operands are decimals (after removing Nullable) the node is built with
/// "sparkDecimalMultiply", or with "sparkDecimalMultiplyEffect" when the query settings contain
/// "arithmetic.decimal.mode" and settingsEqual() matches it against "EFFECT". Either variant
/// receives a third constant argument typed as the expected result type.
/// In every other case the "multiply" function is used.
class FunctionParserMultiply final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserMultiply(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "multiply";

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "multiply"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2);
    }

    /// `func_name` is not used; the function name is chosen from the operand types and settings.
    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * lhs = new_args[0];
        const auto * rhs = new_args[1];

        /// NOTE(review): the modulus and divide parsers in this file take the decimal path when
        /// *either* operand is decimal, whereas here both must be — confirm the difference is intended.
        if (!isDecimal(removeNullable(lhs->result_type)) || !isDecimal(removeNullable(rhs->result_type)))
            return toFunctionNode(actions_dag, "multiply", {lhs, rhs});

        /// Constant column whose only purpose is to carry the expected result type.
        const ActionsDAG::Node * result_type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & query_settings = parser_context->queryContext()->getSettingsRef();
        const char * decimal_function = "sparkDecimalMultiply";
        if (query_settings.has("arithmetic.decimal.mode") && settingsEqual(query_settings, "arithmetic.decimal.mode", "EFFECT"))
            decimal_function = "sparkDecimalMultiplyEffect";

        return toFunctionNode(actions_dag, decimal_function, {lhs, rhs, result_type_node});
    }
};
/// Parser for the Substrait scalar function "modulus".
///
/// If at least one operand is a decimal (after removing Nullable) the node is built with
/// "sparkDecimalModulo", or with "sparkDecimalModuloEffect" when the query settings contain
/// "arithmetic.decimal.mode" and settingsEqual() matches it against "EFFECT". Either variant
/// receives a third constant argument typed as the expected result type.
/// In every other case the "spark_modulo" function is used.
class FunctionParserModulo final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserModulo(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "modulus";

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "modulo"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalModuloDecimalType(p1, s1, p2, s2);
    }

    /// `func_name` is not used; the function name is chosen from the operand types and settings.
    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        const auto * lhs = new_args[0];
        const auto * rhs = new_args[1];

        /// NOTE(review): the add, subtract and multiply parsers in this file require *both*
        /// operands to be decimal for the decimal path; here one is enough — confirm this is intended.
        if (!isDecimal(removeNullable(lhs->result_type)) && !isDecimal(removeNullable(rhs->result_type)))
            return toFunctionNode(actions_dag, "spark_modulo", {lhs, rhs});

        /// Constant column whose only purpose is to carry the expected result type.
        const ActionsDAG::Node * result_type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & query_settings = parser_context->queryContext()->getSettingsRef();
        const char * decimal_function = "sparkDecimalModulo";
        if (query_settings.has("arithmetic.decimal.mode") && settingsEqual(query_settings, "arithmetic.decimal.mode", "EFFECT"))
            decimal_function = "sparkDecimalModuloEffect";

        return toFunctionNode(actions_dag, decimal_function, {lhs, rhs, result_type_node});
    }
};
/// Parser for the Substrait scalar function "divide".
///
/// If at least one operand is a decimal (after removing Nullable) the node is built with
/// "sparkDecimalDivide", or with "sparkDecimalDivideEffect" when the query settings contain
/// "arithmetic.decimal.mode" and settingsEqual() matches it against "EFFECT". Either variant
/// receives a third constant argument typed as the expected result type.
/// In every other case the "sparkDivide" function is used.
class FunctionParserDivide final : public FunctionParserBinaryArithmetic
{
public:
    explicit FunctionParserDivide(ParserContextPtr parser_context_) : FunctionParserBinaryArithmetic(parser_context_) { }

    static constexpr auto name = "divide";

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func) const override { return "divide"; }

protected:
    DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override
    {
        return DecimalType::evalDividetDecimalType(p1, s1, p2, s2);
    }

    /// `func_name` is only checked by the assertion; the function actually called is chosen
    /// from the operand types and settings.
    const DB::ActionsDAG::Node * createFunctionNode(
        DB::ActionsDAG & actions_dag,
        const String & func_name,
        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
        DataTypePtr result_type) const override
    {
        assert(func_name == name);

        const auto * lhs = new_args[0];
        const auto * rhs = new_args[1];

        /// NOTE(review): the add, subtract and multiply parsers in this file require *both*
        /// operands to be decimal for the decimal path; here one is enough — confirm this is intended.
        if (!isDecimal(removeNullable(lhs->result_type)) && !isDecimal(removeNullable(rhs->result_type)))
            return toFunctionNode(actions_dag, "sparkDivide", {lhs, rhs});

        /// Constant column whose only purpose is to carry the expected result type.
        const ActionsDAG::Node * result_type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));

        const auto & query_settings = parser_context->queryContext()->getSettingsRef();
        const char * decimal_function = "sparkDecimalDivide";
        if (query_settings.has("arithmetic.decimal.mode") && settingsEqual(query_settings, "arithmetic.decimal.mode", "EFFECT"))
            decimal_function = "sparkDecimalDivideEffect";

        return toFunctionNode(actions_dag, decimal_function, {lhs, rhs, result_type_node});
    }
};
/// File-local registrar objects, one per arithmetic parser defined above.
/// Presumably each FunctionParserRegister instance registers its parser during static
/// initialization — see FunctionParserRegister for the actual mechanism.
static FunctionParserRegister<FunctionParserPlus> register_plus;
static FunctionParserRegister<FunctionParserMinus> register_minus;
static FunctionParserRegister<FunctionParserMultiply> register_multiply;
static FunctionParserRegister<FunctionParserDivide> register_divide;
static FunctionParserRegister<FunctionParserModulo> register_modulo;
}