// cpp-ch/local-engine/Parser/ExpressionParser.cpp
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ExpressionParser.h"
#include <Columns/ColumnSet.h>
#include <Core/Settings.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeSet.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <DataTypes/getLeastSupertype.h>
#include <Functions/FunctionFactory.h>
#include <IO/WriteBufferFromString.h>
#include <Interpreters/PreparedSets.h>
#include <Parser/FunctionParser.h>
#include <Parser/ParserContext.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/SubstraitParserUtils.h>
#include <Parser/TypeParser.h>
#include <Poco/Logger.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_FUNCTION;
extern const int UNKNOWN_TYPE;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
}
namespace local_engine
{
using namespace DB;
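/// Convert a Substrait literal into a ClickHouse (DataTypePtr, Field) pair.
/// Complex literals (list/map/struct) are parsed recursively; element and value types are widened to their least supertype.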
std::pair<DB::DataTypePtr, DB::Field> LiteralParser::parse(const substrait::Expression_Literal & literal)
{
DB::DataTypePtr type;
DB::Field field;
switch (literal.literal_type_case())
{
case substrait::Expression_Literal::kFp64: {
type = std::make_shared<DB::DataTypeFloat64>();
field = literal.fp64();
break;
}
case substrait::Expression_Literal::kFp32: {
type = std::make_shared<DB::DataTypeFloat32>();
field = literal.fp32();
break;
}
case substrait::Expression_Literal::kString: {
type = std::make_shared<DB::DataTypeString>();
field = literal.string();
break;
}
case substrait::Expression_Literal::kBinary: {
type = std::make_shared<DB::DataTypeString>();
field = literal.binary();
break;
}
case substrait::Expression_Literal::kI64: {
type = std::make_shared<DB::DataTypeInt64>();
field = literal.i64();
break;
}
case substrait::Expression_Literal::kI32: {
type = std::make_shared<DB::DataTypeInt32>();
field = literal.i32();
break;
}
case substrait::Expression_Literal::kBoolean: {
type = DB::DataTypeFactory::instance().get("Bool");
field = literal.boolean() ? UInt8(1) : UInt8(0);
break;
}
case substrait::Expression_Literal::kI16: {
type = std::make_shared<DB::DataTypeInt16>();
field = literal.i16();
break;
}
case substrait::Expression_Literal::kI8: {
type = std::make_shared<DB::DataTypeInt8>();
field = literal.i8();
break;
}
case substrait::Expression_Literal::kDate: {
type = std::make_shared<DB::DataTypeDate32>();
field = literal.date();
break;
}
case substrait::Expression_Literal::kTimestamp: {
type = std::make_shared<DB::DataTypeDateTime64>(6);
field = DecimalField<DB::DateTime64>(literal.timestamp(), 6);
break;
}
case substrait::Expression_Literal::kDecimal: {
UInt32 precision = literal.decimal().precision();
UInt32 scale = literal.decimal().scale();
const auto & bytes = literal.decimal().value();
if (precision <= DB::DataTypeDecimal32::maxPrecision())
{
type = std::make_shared<DB::DataTypeDecimal32>(precision, scale);
auto value = *reinterpret_cast<const Int32 *>(bytes.data());
field = DecimalField<DB::Decimal32>(value, scale);
}
else if (precision <= DataTypeDecimal64::maxPrecision())
{
type = std::make_shared<DB::DataTypeDecimal64>(precision, scale);
auto value = *reinterpret_cast<const Int64 *>(bytes.data());
field = DecimalField<DB::Decimal64>(value, scale);
}
else if (precision <= DataTypeDecimal128::maxPrecision())
{
type = std::make_shared<DB::DataTypeDecimal128>(precision, scale);
String bytes_copy(bytes);
auto value = *reinterpret_cast<DB::Decimal128 *>(bytes_copy.data());
field = DecimalField<DB::Decimal128>(value, scale);
}
else
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision);
break;
}
case substrait::Expression_Literal::kList: {
const auto & values = literal.list().values();
if (values.empty())
{
type = std::make_shared<DataTypeArray>(std::make_shared<DB::DataTypeNothing>());
field = Array();
break;
}
DB::DataTypePtr common_type;
std::tie(common_type, std::ignore) = parse(values[0]);
size_t list_len = values.size();
Array array(list_len);
for (int i = 0; i < static_cast<int>(list_len); ++i)
{
auto type_and_field = parse(values[i]);
common_type = getLeastSupertype(DataTypes{common_type, type_and_field.first});
array[i] = std::move(type_and_field.second);
}
type = std::make_shared<DB::DataTypeArray>(common_type);
field = std::move(array);
break;
}
case substrait::Expression_Literal::kEmptyList: {
type = std::make_shared<DB::DataTypeArray>(std::make_shared<DB::DataTypeNothing>());
field = Array();
break;
}
case substrait::Expression_Literal::kMap: {
const auto & key_values = literal.map().key_values();
if (key_values.empty())
{
type = std::make_shared<DB::DataTypeMap>(std::make_shared<DB::DataTypeNothing>(), std::make_shared<DB::DataTypeNothing>());
field = Map();
break;
}
const auto & first_key_value = key_values[0];
DB::DataTypePtr common_key_type;
std::tie(common_key_type, std::ignore) = parse(first_key_value.key());
DB::DataTypePtr common_value_type;
std::tie(common_value_type, std::ignore) = parse(first_key_value.value());
Map map;
map.reserve(key_values.size());
for (const auto & key_value : key_values)
{
Tuple tuple(2);
DB::DataTypePtr key_type;
std::tie(key_type, tuple[0]) = parse(key_value.key());
/// Every key must have the same type
if (!common_key_type->equals(*key_type))
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"Literal map key type mismatch:{} and {}",
common_key_type->getName(),
key_type->getName());
DB::DataTypePtr value_type;
std::tie(value_type, tuple[1]) = parse(key_value.value());
/// The common value type is the least supertype of all value types
common_value_type = getLeastSupertype(DB::DataTypes{common_value_type, value_type});
map.emplace_back(std::move(tuple));
}
type = std::make_shared<DB::DataTypeMap>(common_key_type, common_value_type);
field = std::move(map);
break;
}
case substrait::Expression_Literal::kEmptyMap: {
type = std::make_shared<DB::DataTypeMap>(std::make_shared<DB::DataTypeNothing>(), std::make_shared<DB::DataTypeNothing>());
field = Map();
break;
}
case substrait::Expression_Literal::kStruct: {
const auto & fields = literal.struct_().fields();
DB::DataTypes types;
types.reserve(fields.size());
Tuple tuple;
tuple.reserve(fields.size());
for (const auto & f : fields)
{
DB::DataTypePtr field_type;
DB::Field field_value;
std::tie(field_type, field_value) = parse(f);
types.emplace_back(std::move(field_type));
tuple.emplace_back(std::move(field_value));
}
type = std::make_shared<DB::DataTypeTuple>(types);
field = std::move(tuple);
break;
}
case substrait::Expression_Literal::kNull: {
type = TypeParser::parseType(literal.null());
field = DB::Field{};
break;
}
default: {
throw DB::Exception(
DB::ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case()));
}
}
return std::make_pair(std::move(type), std::move(field));
}
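/// When enabled (the default), structurally equal nodes already present in the ActionsDAG are reused
/// instead of adding duplicates: a simple form of common-subexpression elimination (CSE).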
static const std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser";
bool ExpressionParser::reuseCSE() const
{
return context->queryContext()->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true);
}
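/// Add a constant column node holding `field` to the DAG. If CSE reuse is enabled and a structurally
/// equal node already exists, that node is returned instead of the newly added one.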
ExpressionParser::NodeRawConstPtr
ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr & type, const DB::Field & field) const
{
String name = toString(field).substr(0, 10);
name = getUniqueName(name);
const auto * res_node = &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name));
if (reuseCSE())
{
// The new node res_node stays in the ActionsDAG but does not affect execution,
// and it will be removed once `ActionsDAG::removeUnusedActions` is called.
if (const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag))
res_node = exists_node;
}
return res_node;
}
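/// Recursively translate a Substrait expression (literal, field reference, cast, if-then,
/// scalar function or IN-list) into ActionsDAG nodes and return the resulting node.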
ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
switch (rel.rex_type_case())
{
case substrait::Expression::RexTypeCase::kLiteral: {
DB::DataTypePtr type;
DB::Field field;
std::tie(type, field) = LiteralParser::parse(rel.literal());
return addConstColumn(actions_dag, type, field);
}
case substrait::Expression::RexTypeCase::kSelection: {
auto field_index = SubstraitParserUtils::getStructFieldIndex(rel);
if (!field_index)
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");
const auto * field = actions_dag.getInputs()[*field_index];
return field;
}
case substrait::Expression::RexTypeCase::kCast: {
if (!rel.cast().has_type() || !rel.cast().has_input())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cast node is missing type or input.");
ActionsDAG::NodeRawConstPtrs args;
const auto & input = rel.cast().input();
args.emplace_back(parseExpression(actions_dag, input));
const auto & substrait_type = rel.cast().type();
const auto & input_type = args[0]->result_type;
DataTypePtr denull_input_type = removeNullable(input_type);
DataTypePtr output_type = TypeParser::parseType(substrait_type);
DataTypePtr denull_output_type = removeNullable(output_type);
const ActionsDAG::Node * result_node = nullptr;
if (substrait_type.has_binary())
{
/// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args);
}
else if (isString(denull_input_type) && isDate32(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkToDate", args);
else if (isString(denull_input_type) && isDateTime64(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkToDateTime", args);
else if (isDecimal(denull_input_type) && isString(denull_output_type))
{
/// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x, scale)
UInt8 scale = getDecimalScale(*denull_input_type);
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale)));
result_node = toFunctionNode(actions_dag, "toDecimalString", args);
}
else if (isFloat(denull_input_type) && isInt(denull_output_type))
{
String function_name = "sparkCastFloatTo" + denull_output_type->getName();
result_node = toFunctionNode(actions_dag, function_name, args);
}
else if (isFloat(denull_input_type) && isString(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkCastFloatToString", args);
else if ((isDecimal(denull_input_type) || isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
{
int precision = substrait_type.decimal().precision();
int scale = substrait_type.decimal().scale();
if (precision)
{
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision));
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale));
result_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
}
else
{
/// Fall back to a plain CAST when precision is 0; otherwise result_node would stay null
/// and be dereferenced below.
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "CAST", args);
}
}
else if ((isMap(denull_input_type) || isArray(denull_input_type) || isTuple(denull_input_type)) && isString(denull_output_type))
{
/// https://github.com/apache/incubator-gluten/issues/9049
result_node = toFunctionNode(actions_dag, "sparkCastComplexTypesToString", args);
}
else if (isString(denull_input_type) && substrait_type.has_bool_())
{
/// cast(string to boolean)
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args);
}
else if (isString(denull_input_type) && isInt(denull_output_type))
{
/// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT)
/// Refer to https://github.com/apache/incubator-gluten/issues/4956 and https://github.com/apache/incubator-gluten/issues/8598
const auto * trim_str_arg = addConstColumn(actions_dag, std::make_shared<DataTypeString>(), " \t\n\r\f");
args[0] = toFunctionNode(actions_dag, "trimBothSpark", {args[0], trim_str_arg});
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "CAST", args);
}
else
{
/// Common process: CAST(input, type)
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "CAST", args);
}
actions_dag.addOrReplaceInOutputs(*result_node);
return result_node;
}
case substrait::Expression::RexTypeCase::kIfThen: {
const auto & if_then = rel.if_then();
DB::FunctionOverloadResolverPtr function_ptr = nullptr;
auto condition_nums = if_then.ifs_size();
if (condition_nums == 1)
function_ptr = DB::FunctionFactory::instance().get("if", context->queryContext());
else
function_ptr = FunctionFactory::instance().get("multiIf", context->queryContext());
DB::ActionsDAG::NodeRawConstPtrs args;
for (int i = 0; i < condition_nums; ++i)
{
const auto & ifs = if_then.ifs(i);
const auto * if_node = parseExpression(actions_dag, ifs.if_());
args.emplace_back(if_node);
const auto * then_node = parseExpression(actions_dag, ifs.then());
args.emplace_back(then_node);
}
const auto * else_node = parseExpression(actions_dag, if_then.else_());
args.emplace_back(else_node);
std::string args_name = join(args, ',');
std::string result_name;
if (condition_nums == 1)
result_name = "if(" + args_name + ")";
else
result_name = "multiIf(" + args_name + ")";
const auto * function_node = &actions_dag.addFunction(function_ptr, args, result_name);
actions_dag.addOrReplaceInOutputs(*function_node);
return function_node;
}
case substrait::Expression::RexTypeCase::kScalarFunction: {
return parseFunction(rel.scalar_function(), actions_dag);
}
case substrait::Expression::RexTypeCase::kSingularOrList: {
const auto & options = rel.singular_or_list().options();
/// An empty options list always evaluates to false
if (options.empty())
return addConstColumn(actions_dag, std::make_shared<DB::DataTypeUInt8>(), 0);
/// options should be literals
if (!options[0].has_literal())
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type");
DB::ActionsDAG::NodeRawConstPtrs args;
args.emplace_back(parseExpression(actions_dag, rel.singular_or_list().value()));
bool nullable = false;
int options_len = options.size();
for (int i = 0; i < options_len; ++i)
{
if (!options[i].has_literal())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!");
if (!nullable)
nullable = options[i].literal().has_null();
}
DB::DataTypePtr elem_type;
std::vector<std::pair<DB::DataTypePtr, DB::Field>> options_type_and_field;
auto first_option = LiteralParser::parse(options[0].literal());
elem_type = wrapNullableType(nullable, first_option.first);
options_type_and_field.emplace_back(std::move(first_option));
for (int i = 1; i < options_len; ++i)
{
auto type_and_field = LiteralParser::parse(options[i].literal());
auto option_type = wrapNullableType(nullable, type_and_field.first);
if (!elem_type->equals(*option_type))
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"SingularOrList options type mismatch:{} and {}",
elem_type->getName(),
option_type->getName());
options_type_and_field.emplace_back(std::move(type_and_field));
}
// check tuple internal types
if (isTuple(elem_type) && isTuple(args[0]->result_type))
{
// Spark guarantees that the types of tuples in the 'in' filter are completely consistent.
// See org.apache.spark.sql.types.DataType#equalsStructurally
// Additionally, the mapping from Spark types to ClickHouse types is one-to-one, See TypeParser.cpp
// So we can directly use the type of the value being tested (args[0]) as the element type, avoiding nullable mismatches
elem_type = args[0]->result_type;
}
DB::MutableColumnPtr elem_column = elem_type->createColumn();
elem_column->reserve(options_len);
for (int i = 0; i < options_len; ++i)
elem_column->insert(options_type_and_field[i].second);
auto name = getUniqueName("__set");
ColumnWithTypeAndName elem_block{std::move(elem_column), elem_type, name};
PreparedSets prepared_sets;
FutureSet::Hash empty_key;
auto future_set = prepared_sets.addFromTuple(empty_key, nullptr, {elem_block}, context->queryContext()->getSettingsRef());
auto arg = DB::ColumnSet::create(1, std::move(future_set));
args.emplace_back(&actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(arg), std::make_shared<DB::DataTypeSet>(), name)));
const auto * function_node = toFunctionNode(actions_dag, "in", args);
actions_dag.addOrReplaceInOutputs(*function_node);
if (nullable)
{
/// If the set contains `null` and the value is not in the set:
/// In Spark: return `null`, the standard ANSI behaviour (SPARK-37920).
/// In CH: return `false`.
/// So we use if(a, b, c) to turn `false` into `null` whenever the set contains `null`.
auto type = wrapNullableType(true, function_node->result_type);
DB::ActionsDAG::NodeRawConstPtrs cast_args(
{function_node, addConstColumn(actions_dag, type, true), addConstColumn(actions_dag, type, DB::Field())});
function_node = toFunctionNode(actions_dag, "if", cast_args);
actions_dag.addOrReplaceInOutputs(*function_node);
}
return function_node;
}
default:
throw DB::Exception(
DB::ErrorCodes::UNKNOWN_TYPE,
"Unsupported spark expression type {} : {}",
magic_enum::enum_name(rel.rex_type_case()),
rel.DebugString());
}
}
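/// Build a projection ActionsDAG from a list of Substrait expressions over `header`.
/// Functions that produce multiple columns (explode/posexplode/json_tuple) are handled specially,
/// and duplicated result names are disambiguated with unique aliases.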
DB::ActionsDAG
ExpressionParser::expressionsToActionsDAG(const std::vector<substrait::Expression> & expressions, const DB::Block & header) const
{
DB::ActionsDAG actions_dag(header.getNamesAndTypesList());
DB::NamesWithAliases required_columns;
std::set<String> distinct_columns;
for (const auto & expr : expressions)
{
if (auto field_index = SubstraitParserUtils::getStructFieldIndex(expr))
{
auto col_name = header.getByPosition(*field_index).name;
const DB::ActionsDAG::Node * field = actions_dag.tryFindInOutputs(col_name);
if (!field)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found {} in actions dag's output", col_name);
if (distinct_columns.contains(field->result_name))
{
auto unique_name = getUniqueName(field->result_name);
required_columns.emplace_back(DB::NameWithAlias(field->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(DB::NameWithAlias(field->result_name, field->result_name));
distinct_columns.emplace(field->result_name);
}
}
else if (expr.has_scalar_function())
{
const auto & scalar_function = expr.scalar_function();
auto signature_name = getFunctionNameInSignature(scalar_function);
std::vector<String> result_names;
if (signature_name == "explode")
{
auto result_nodes = parseArrayJoin(scalar_function, actions_dag, false);
for (const auto * node : result_nodes)
result_names.emplace_back(node->result_name);
}
else if (signature_name == "posexplode")
{
auto result_nodes = parseArrayJoin(scalar_function, actions_dag, true);
for (const auto * node : result_nodes)
result_names.emplace_back(node->result_name);
}
else if (signature_name == "json_tuple")
{
auto result_nodes = parseJsonTuple(scalar_function, actions_dag);
for (const auto * node : result_nodes)
result_names.emplace_back(node->result_name);
}
else
{
result_names.resize(1);
result_names[0] = parseFunction(scalar_function, actions_dag, true)->result_name;
}
for (const auto & result_name : result_names)
{
if (result_name.empty())
continue;
if (distinct_columns.contains(result_name))
{
auto unique_name = getUniqueName(result_name);
required_columns.emplace_back(NameWithAlias(result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(result_name, result_name));
distinct_columns.emplace(result_name);
}
}
}
else if (expr.has_cast() || expr.has_if_then() || expr.has_literal() || expr.has_singular_or_list())
{
const auto * node = parseExpression(actions_dag, expr);
actions_dag.addOrReplaceInOutputs(*node);
if (distinct_columns.contains(node->result_name))
{
auto unique_name = getUniqueName(node->result_name);
required_columns.emplace_back(NameWithAlias(node->result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name));
distinct_columns.emplace(node->result_name);
}
}
else
throw DB::Exception(
DB::ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
}
actions_dag.project(required_columns);
actions_dag.appendInputsForUnusedColumns(header);
return actions_dag;
}
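/// Parse all value arguments of a scalar function into DAG nodes.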
DB::ActionsDAG::NodeRawConstPtrs
ExpressionParser::parseFunctionArguments(DB::ActionsDAG & actions_dag, const substrait::Expression_ScalarFunction & func) const
{
DB::ActionsDAG::NodeRawConstPtrs parsed_args;
parsed_args.reserve(func.arguments_size());
for (Int32 i = 0; i < func.arguments_size(); ++i)
{
const auto & arg = func.arguments(i);
if (!arg.has_value())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow scalar function:{}\n\n{}", func.DebugString(), arg.DebugString());
const auto * node = parseExpression(actions_dag, arg.value());
parsed_args.emplace_back(node);
}
return parsed_args;
}
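/// Parse a scalar function via its registered FunctionParser and optionally register the
/// result node in the DAG outputs.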
ExpressionParser::NodeRawConstPtr
ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool add_to_output) const
{
auto function_signature = getFunctionNameInSignature(func);
auto function_parser = FunctionParserFactory::instance().get(function_signature, context);
const auto * function_node = function_parser->parse(func, actions_dag);
if (add_to_output)
actions_dag.addOrReplaceInOutputs(*function_node);
return function_node;
}
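/// Add a function node `ch_function_name(args...)` to the DAG. With CSE reuse enabled, an existing
/// structurally equal node is returned instead, aliased if a different result name was requested.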
ExpressionParser::NodeRawConstPtr ExpressionParser::toFunctionNode(
DB::ActionsDAG & actions_dag,
const String & ch_function_name,
const DB::ActionsDAG::NodeRawConstPtrs & args,
const String & result_name_) const
{
auto function_builder = FunctionFactory::instance().get(ch_function_name, context->queryContext());
std::string result_name = result_name_;
if (result_name.empty())
{
std::string args_name = join(args, ',');
result_name = ch_function_name + "(" + args_name + ")";
}
const auto * res_node = &actions_dag.addFunction(function_builder, args, result_name);
if (reuseCSE())
{
const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag);
if (exists_node)
{
if (result_name_.empty() || result_name == exists_node->result_name)
res_node = exists_node;
else
res_node = &actions_dag.addAlias(*exists_node, result_name);
}
}
return res_node;
}
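/// Shared across all ExpressionParser instances so that generated names are unique process-wide.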
std::atomic<UInt64> ExpressionParser::unique_name_counter = 0;
String ExpressionParser::getUniqueName(const String & name) const
{
return name + "_" + std::to_string(unique_name_counter++);
}
String ExpressionParser::getFunctionNameInSignature(const substrait::Expression_ScalarFunction & func_) const
{
return getFunctionNameInSignature(func_.function_reference());
}
String ExpressionParser::getFunctionNameInSignature(UInt32 func_ref_) const
{
auto function_sig = context->getFunctionNameInSignature(func_ref_);
if (!function_sig)
throw DB::Exception(DB::ErrorCodes::UNKNOWN_FUNCTION, "Unknown function anchor: {}", func_ref_);
return *function_sig;
}
String ExpressionParser::getFunctionName(const substrait::Expression_ScalarFunction & func_) const
{
auto signature_name = getFunctionNameInSignature(func_);
auto function_parser = FunctionParserFactory::instance().tryGet(signature_name, context);
if (!function_parser)
throw DB::Exception(DB::ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", signature_name);
return function_parser->getCHFunctionName(func_);
}
String ExpressionParser::safeGetFunctionName(const substrait::Expression_ScalarFunction & func_) const
{
try
{
return getFunctionName(func_);
}
catch (const DB::Exception &)
{
return "";
}
}
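/// Prepare the input argument of arrayJoin for explode/posexplode: null values are replaced with an
/// empty array/map so the input becomes non-nullable, and for posexplode the argument is wrapped into
/// mapFromArrays(range(length(arg)), arg) so positions are carried along. Sets `is_map` for the caller.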
DB::ActionsDAG::NodeRawConstPtrs ExpressionParser::parseArrayJoinArguments(
const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position, bool & is_map) const
{
auto parsed_args = parseFunctionArguments(actions_dag, func);
const auto arg0_type = DB::removeNullable(parsed_args[0]->result_type);
if (isMap(arg0_type))
is_map = true;
else if (isArray(arg0_type))
is_map = false;
else
throw DB::Exception(
DB::ErrorCodes::BAD_ARGUMENTS, "Argument type of arrayJoin should be Array or Map but is {}", arg0_type->getName());
/// Remove Nullable from the input argument of arrayJoin because arrayJoin only accepts non-nullable input
/// array() or map()
const auto * empty_node = addConstColumn(actions_dag, arg0_type, is_map ? DB::Field(Map()) : DB::Field(Array()));
/// ifNull(arg, array()) or ifNull(arg, map())
const auto * if_null_node = toFunctionNode(actions_dag, "ifNull", {parsed_args[0], empty_node});
/// assumeNotNull(ifNull(arg, array())) or assumeNotNull(ifNull(arg, map()))
const auto * not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {if_null_node});
/// Wrap with the materialize function to make sure the column fed into the ARRAY JOIN step is materialized
const auto * arg = &actions_dag.materializeNode(*not_null_node);
/// If the Spark function is posexplode, we need to add a position column along with the input argument
if (position)
{
/// length(arg)
const auto * length_node = toFunctionNode(actions_dag, "length", {arg});
/// range(length(arg))
const auto * range_node = toFunctionNode(actions_dag, "range", {length_node});
/// mapFromArrays(range(length(arg)), arg)
arg = toFunctionNode(actions_dag, "mapFromArrays", {range_node, arg});
}
parsed_args[0] = arg;
return parsed_args;
}
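/// Translate Spark explode/posexplode into a ClickHouse arrayJoin node plus sparkTupleElement
/// projections, returning one node per output column: (col), (key, value), (pos, col) or (pos, key, value).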
DB::ActionsDAG::NodeRawConstPtrs
ExpressionParser::parseArrayJoin(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position) const
{
/// Whether the input argument of explode/posexplode is of map type
bool is_map = false;
auto parsed_args = parseArrayJoinArguments(func, actions_dag, position, is_map);
/// Note: Make sure result_name stays the same after applying the arrayJoin function, which makes it much easier to transform the arrayJoin function into an ARRAY JOIN step.
/// Otherwise an alias node would have to be appended after the ARRAY JOIN step, which is not a graceful implementation.
const auto & arg_not_null = parsed_args[0];
auto array_join_name = arg_not_null->result_name;
/// arrayJoin(arg_not_null)
const auto * array_join_node = &actions_dag.addArrayJoin(*arg_not_null, array_join_name);
auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context->queryContext());
auto tuple_index_type = std::make_shared<DB::DataTypeUInt32>();
auto add_tuple_element = [&](const DB::ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node *
{
DB::ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i)));
const auto * index_node = &actions_dag.addColumn(std::move(index_col));
auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")";
return &actions_dag.addFunction(tuple_element_builder, {tuple_node, index_node}, result_name);
};
/// Special processing to keep compatible with Spark
if (!position)
{
/// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map)
if (is_map)
{
/// In Spark: explode(map(k, v)) outputs 2 columns with default names "key" and "value".
/// In CH: arrayJoin(map(k, v)) outputs 1 column of Tuple type.
/// So we must wrap arrayJoin with the sparkTupleElement function for compatibility.
/// arrayJoin(arg_not_null).1
const auto * key_node = add_tuple_element(array_join_node, 1);
/// arrayJoin(arg_not_null).2
const auto * val_node = add_tuple_element(array_join_node, 2);
actions_dag.addOrReplaceInOutputs(*key_node);
actions_dag.addOrReplaceInOutputs(*val_node);
return {key_node, val_node};
}
else
{
actions_dag.addOrReplaceInOutputs(*array_join_node);
return {array_join_node};
}
}
else
{
/// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map)
/// In Spark: posexplode(array_or_map) outputs 2 or 3 columns: (pos, col) or (pos, key, value)
/// In CH: arrayJoin(map(k, v)) outputs 1 column of Tuple type.
/// So we must wrap arrayJoin with the sparkTupleElement function for compatibility.
/// pos = cast(arrayJoin(arg_not_null).1, "Int32")
const auto * pos_node = add_tuple_element(array_join_node, 1);
pos_node = ActionsDAGUtil::convertNodeType(actions_dag, pos_node, INT());
/// if is_map is false, output col = arrayJoin(arg_not_null).2
/// if is_map is true, output (key, value) = arrayJoin(arg_not_null).2
const auto * item_node = add_tuple_element(array_join_node, 2);
if (is_map)
{
/// key = arrayJoin(arg_not_null).2.1
const auto * key_node = add_tuple_element(item_node, 1);
/// value = arrayJoin(arg_not_null).2.2
const auto * val_node = add_tuple_element(item_node, 2);
actions_dag.addOrReplaceInOutputs(*pos_node);
actions_dag.addOrReplaceInOutputs(*key_node);
actions_dag.addOrReplaceInOutputs(*val_node);
return {pos_node, key_node, val_node};
}
else
{
actions_dag.addOrReplaceInOutputs(*pos_node);
actions_dag.addOrReplaceInOutputs(*item_node);
return {pos_node, item_node};
}
}
}
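/// Translate Spark json_tuple into a single JSONExtract call plus sparkTupleElement projections,
/// e.g. json_tuple(j, 'a', 'b') becomes JSONExtract(j, 'Tuple(a Nullable(String), b Nullable(String))')
/// followed by one tuple-element extraction per requested field.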
DB::ActionsDAG::NodeRawConstPtrs
ExpressionParser::parseJsonTuple(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) const
{
const auto & pb_args = func.arguments();
if (pb_args.size() < 2)
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "json_tuple function has at least 2 arguments");
const auto & first_arg = pb_args[0].value();
const auto * json_expr_node = parseExpression(actions_dag, first_arg);
DB::WriteBufferFromOwnString write_buffer;
write_buffer << "Tuple(";
for (int i = 1; i < pb_args.size(); ++i)
{
if (i > 1)
write_buffer << ", ";
const auto & arg = pb_args[i].value();
if (!arg.has_literal() || !arg.literal().has_string())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "json_tuple function requires string literal arguments");
write_buffer << arg.literal().string() << " Nullable(String)";
}
write_buffer << ")";
const auto * extract_expr_node = addConstColumn(actions_dag, std::make_shared<DB::DataTypeString>(), write_buffer.str());
auto json_extract_builder = DB::FunctionFactory::instance().get("JSONExtract", context->queryContext());
auto json_extract_result_name = "JSONExtract(" + json_expr_node->result_name + ", " + extract_expr_node->result_name + ")";
const auto * json_extract_node
= &actions_dag.addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name);
auto tuple_element_builder = DB::FunctionFactory::instance().get("sparkTupleElement", context->queryContext());
auto tuple_index_type = std::make_shared<DB::DataTypeUInt32>();
auto add_tuple_element = [&](const DB::ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node *
{
DB::ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i)));
const auto * index_node = &actions_dag.addColumn(std::move(index_col));
auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")";
return &actions_dag.addFunction(tuple_element_builder, {tuple_node, index_node}, result_name);
};
DB::ActionsDAG::NodeRawConstPtrs res_nodes;
for (int i = 1; i < pb_args.size(); ++i)
{
const auto * tuple_node = add_tuple_element(json_extract_node, i);
actions_dag.addOrReplaceInOutputs(*tuple_node);
res_nodes.push_back(tuple_node);
}
return res_nodes;
}
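/// Whether a result type is safe for value comparison in areEqualNodes. Types not listed here are
/// conservatively rejected, which disables CSE reuse for nodes producing them.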
static bool isAllowedDataType(const DB::IDataType & data_type)
{
DB::WhichDataType which(data_type);
if (which.isNullable())
{
const auto * null_type = typeid_cast<const DB::DataTypeNullable *>(&data_type);
return isAllowedDataType(*(null_type->getNestedType()));
}
else if (which.isNumber() || which.isStringOrFixedString() || which.isDateOrDate32OrDateTimeOrDateTime64())
return true;
else if (which.isArray())
{
auto nested_type = typeid_cast<const DB::DataTypeArray *>(&data_type)->getNestedType();
return isAllowedDataType(*nested_type);
}
else if (which.isTuple())
{
const auto * tuple_type = typeid_cast<const DB::DataTypeTuple *>(&data_type);
for (const auto & nested_type : tuple_type->getElements())
if (!isAllowedDataType(*nested_type))
return false;
return true;
}
else if (which.isMap())
{
const auto * map_type = typeid_cast<const DB::DataTypeMap *>(&data_type);
return isAllowedDataType(*(map_type->getKeyType())) && isAllowedDataType(*(map_type->getValueType()));
}
return false;
}
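/// Structural equality of two DAG nodes: same action type, same result type and, recursively, equal
/// children. Non-deterministic functions, ARRAY JOIN nodes and disallowed result types never compare equal.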
bool ExpressionParser::areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b)
{
if (a == b)
return true;
if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size()
|| !a->isDeterministic() || !b->isDeterministic() || !isAllowedDataType(*(a->result_type)))
return false;
switch (a->type)
{
case DB::ActionsDAG::ActionType::INPUT: {
if (a->result_name != b->result_name)
return false;
break;
}
case DB::ActionsDAG::ActionType::ALIAS: {
if (a->result_name != b->result_name)
return false;
break;
}
case DB::ActionsDAG::ActionType::COLUMN: {
// dummy columns (e.g. ColumnSet) cannot be compared by value
if (typeid_cast<const DB::ColumnSet *>(a->column.get()))
return a->result_name == b->result_name;
if (a->column->compareAt(0, 0, *(b->column), 1) != 0)
return false;
break;
}
case DB::ActionsDAG::ActionType::ARRAY_JOIN: {
return false;
}
case DB::ActionsDAG::ActionType::FUNCTION: {
if (!a->function_base->isDeterministic() || a->function_base->getName() != b->function_base->getName())
return false;
break;
}
default: {
LOG_WARNING(
getLogger("ExpressionParser"),
"Unknow node type. type:{}, data type:{}, result_name:{}",
a->type,
a->result_type->getName(),
a->result_name);
return false;
}
}
for (size_t i = 0; i < a->children.size(); ++i)
if (!areEqualNodes(a->children[i], b->children[i]))
return false;
LOG_TEST(
getLogger("ExpressionParser"),
"Nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
a->type,
a->result_type->getName(),
a->result_name,
b->type,
b->result_type->getName(),
b->result_name);
return true;
}
// Since each new node is appended to the end of ActionsDAG::nodes, we expect to find an earlier equivalent node; the new node will be dropped later.
ExpressionParser::NodeRawConstPtr
ExpressionParser::findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const
{
for (const auto & node : actions_dag.getNodes())
{
if (target == &node)
continue;
if (areEqualNodes(target, &node))
{
LOG_TEST(
getLogger("ExpressionParser"),
"Two nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
target->type,
target->result_type->getName(),
target->result_name,
node.type,
node.result_type->getName(),
node.result_name);
return &node;
}
}
return nullptr;
}
}