cpp-ch/local-engine/Parser/SerializedPlanParser.cpp (1,971 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "SerializedPlanParser.h" #include <algorithm> #include <memory> #include <string> #include <string_view> #include <AggregateFunctions/AggregateFunctionFactory.h> #include <Columns/ColumnSet.h> #include <Core/Block.h> #include <Core/ColumnWithTypeAndName.h> #include <Core/Names.h> #include <Core/NamesAndTypes.h> #include <Core/Types.h> #include <Core/Field.h> #include <DataTypes/DataTypeAggregateFunction.h> #include <DataTypes/DataTypeArray.h> #include <DataTypes/DataTypeDate32.h> #include <DataTypes/DataTypeDateTime64.h> #include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeMap.h> #include <DataTypes/DataTypeNothing.h> #include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeSet.h> #include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeTuple.h> #include <DataTypes/DataTypesDecimal.h> #include <DataTypes/DataTypesNumber.h> #include <DataTypes/IDataType.h> #include <DataTypes/Serializations/ISerialization.h> #include <DataTypes/getLeastSupertype.h> #include <Functions/FunctionFactory.h> #include <Functions/FunctionsConversion.h> #include <Interpreters/ActionsDAG.h> #include <Interpreters/ActionsVisitor.h> #include <Interpreters/CollectJoinOnKeysVisitor.h> #include <Interpreters/Context.h> #include <Interpreters/PreparedSets.h> #include <Interpreters/ProcessList.h> #include <Interpreters/QueryPriorities.h> #include <Join/StorageJoinFromReadBuffer.h> #include <Operator/BlocksBufferPoolTransform.h> #include <Parser/FunctionParser.h> #include <Parser/JoinRelParser.h> #include <Parser/RelParser.h> #include <Parser/TypeParser.h> #include <Parser/MergeTreeRelParser.h> #include <Parsers/ASTIdentifier.h> #include <Parsers/ExpressionListParsers.h> #include <Processors/Executors/PullingAsyncPipelineExecutor.h> #include <Processors/Formats/Impl/ArrowBlockOutputFormat.h> #include <Processors/QueryPlan/AggregatingStep.h> #include <Processors/QueryPlan/ExpressionStep.h> #include <Processors/QueryPlan/FilterStep.h> #include <Processors/QueryPlan/LimitStep.h> #include <Processors/QueryPlan/MergingAggregatedStep.h> #include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h> #include <Processors/QueryPlan/QueryPlan.h> #include <Processors/QueryPlan/ReadFromPreparedSource.h> #include <Processors/Transforms/AggregatingTransform.h> #include <Processors/Transforms/MaterializingTransform.h> #include <QueryPipeline/Pipe.h> #include <QueryPipeline/QueryPipelineBuilder.h> #include <QueryPipeline/printPipeline.h> #include <Storages/CustomStorageMergeTree.h> #include <Storages/IStorage.h> #include <Storages/MergeTree/MergeTreeData.h> #include <Storages/StorageMergeTreeFactory.h> #include <Storages/SubstraitSource/SubstraitFileSource.h> #include <Storages/SubstraitSource/SubstraitFileSourceStep.h> #include <google/protobuf/util/json_util.h> #include <google/protobuf/wrappers.pb.h> #include <Poco/Util/MapConfiguration.h> #include <Common/CHUtil.h> #include <Common/Exception.h> #include <Common/MergeTreeTool.h> #include <Common/logger_useful.h> #include <Common/typeid_cast.h> namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int UNKNOWN_TYPE; extern const int BAD_ARGUMENTS; extern const int NO_SUCH_DATA_PART; extern const int UNKNOWN_FUNCTION; extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int INVALID_JOIN_ON_EXPRESSION; } } namespace local_engine { using namespace DB; std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c) { std::string res; for (size_t i = 0; i < v.size(); ++i) { if (i) res += c; res += v[i]->result_name; } return res; } void logDebugMessage(const google::protobuf::Message & message, const char * type) { auto * logger = &Poco::Logger::get("SerializedPlanParser"); if (logger->debug()) { namespace pb_util = google::protobuf::util; pb_util::JsonOptions options; std::string json; auto s = pb_util::MessageToJsonString(message, &json, options); if (!s.ok()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert {} to Json", type); LOG_DEBUG(logger, "{}:\n{}", type, json); } } const ActionsDAG::Node * SerializedPlanParser::addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field) { return &actions_dag->addColumn( ColumnWithTypeAndName(type->createColumnConst(1, field), type, getUniqueName(toString(field).substr(0, 10)))); } void SerializedPlanParser::parseExtensions( const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions) { for (const auto & extension : extensions) { if (extension.has_extension_function()) { function_mapping.emplace( std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); } } } std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG( const std::vector<substrait::Expression> & expressions, const DB::Block & header, const DB::Block & read_schema) { auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); NamesWithAliases required_columns; std::set<String> distinct_columns; for (const auto & expr : expressions) { if (expr.has_selection()) { auto position = expr.selection().direct_reference().struct_field().field(); auto col_name = read_schema.getByPosition(position).name; const ActionsDAG::Node * field = actions_dag->tryFindInOutputs(col_name); if (distinct_columns.contains(field->result_name)) { auto unique_name = getUniqueName(field->result_name); required_columns.emplace_back(NameWithAlias(field->result_name, unique_name)); distinct_columns.emplace(unique_name); } else { required_columns.emplace_back(NameWithAlias(field->result_name, field->result_name)); distinct_columns.emplace(field->result_name); } } else if (expr.has_scalar_function()) { const auto & scalar_function = expr.scalar_function(); auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference())); std::vector<String> result_names; if (startsWith(function_signature, "explode:")) actions_dag = parseArrayJoin(header, expr, result_names, actions_dag, true, false); else if (startsWith(function_signature, "posexplode:")) actions_dag = parseArrayJoin(header, expr, result_names, actions_dag, true, true); else if (startsWith(function_signature, "json_tuple:")) actions_dag = parseJsonTuple(header, expr, result_names, actions_dag, true, false); else { result_names.resize(1); actions_dag = parseFunction(header, expr, result_names[0], actions_dag, true); } for (const auto & result_name : result_names) { if (result_name.empty()) continue; if (distinct_columns.contains(result_name)) { auto unique_name = getUniqueName(result_name); required_columns.emplace_back(NameWithAlias(result_name, unique_name)); distinct_columns.emplace(unique_name); } else { required_columns.emplace_back(NameWithAlias(result_name, result_name)); distinct_columns.emplace(result_name); } } } else if (expr.has_cast() || expr.has_if_then() || expr.has_literal()) { const auto * node = parseExpression(actions_dag, expr); actions_dag->addOrReplaceInOutputs(*node); if (distinct_columns.contains(node->result_name)) { auto unique_name = getUniqueName(node->result_name); required_columns.emplace_back(NameWithAlias(node->result_name, unique_name)); distinct_columns.emplace(unique_name); } else { required_columns.emplace_back(NameWithAlias(node->result_name, node->result_name)); distinct_columns.emplace(node->result_name); } } else throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case())); } actions_dag->project(required_columns); return actions_dag; } std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool null_on_overflow) { std::string ch_function_name; UInt32 precision = decimal.precision(); if (precision <= DataTypeDecimal32::maxPrecision()) ch_function_name = "toDecimal32"; else if (precision <= DataTypeDecimal64::maxPrecision()) ch_function_name = "toDecimal64"; else if (precision <= DataTypeDecimal128::maxPrecision()) ch_function_name = "toDecimal128"; else throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); if (null_on_overflow) ch_function_name = ch_function_name + "OrNull"; return ch_function_name; } bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) { return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with("iterator"); } bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel) { assert(rel.has_advanced_extension()); bool is_read_from_merge_tree; google::protobuf::StringValue optimization; optimization.ParseFromString(rel.advanced_extension().optimization().value()); ReadBufferFromString in(optimization.value()); assertString("isMergeTree=", in); readBoolText(is_read_from_merge_tree, in); assertChar('\n', in); return is_read_from_merge_tree; } QueryPlanStepPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrait::ReadRel & rel) { auto header = TypeParser::buildBlockFromNamedStruct(rel.base_schema()); substrait::ReadRel::LocalFiles local_files; if (rel.has_local_files()) local_files = rel.local_files(); else local_files = parseLocalFiles(split_infos.at(nextSplitInfoIndex())); auto source = std::make_shared<SubstraitFileSource>(context, header, local_files); auto source_pipe = Pipe(source); auto source_step = std::make_unique<SubstraitFileSourceStep>(context, std::move(source_pipe), "substrait local files"); source_step->setStepDescription("read local files"); if (rel.has_filter()) { const ActionsDAGPtr actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); const ActionsDAG::Node * filter_node = parseExpression(actions_dag, rel.filter()); actions_dag->addOrReplaceInOutputs(*filter_node); source_step->addFilter(actions_dag, filter_node); } return source_step; } QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait::ReadRel & rel) { assert(rel.has_local_files()); assert(rel.local_files().items().size() == 1); auto iter = rel.local_files().items().at(0).uri_file(); auto pos = iter.find(':'); auto iter_index = std::stoi(iter.substr(pos + 1, iter.size())); auto source = std::make_shared<SourceFromJavaIter>( context, TypeParser::buildBlockFromNamedStruct(rel.base_schema()), input_iters[iter_index], materialize_inputs[iter_index]); QueryPlanStepPtr source_step = std::make_unique<ReadFromPreparedSource>(Pipe(source)); source_step->setStepDescription("Read From Java Iter"); return source_step; } IQueryPlanStep * SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns) { if (columns.empty()) return nullptr; auto remove_nullable_actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(plan.getCurrentDataStream().header)); removeNullable(columns, remove_nullable_actions_dag); auto expression_step = std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), remove_nullable_actions_dag); expression_step->setStepDescription("Remove nullable properties"); auto * step_ptr = expression_step.get(); plan.addStep(std::move(expression_step)); return step_ptr; } PrewhereInfoPtr SerializedPlanParser::parsePreWhereInfo(const substrait::Expression & rel, Block & input) { auto prewhere_info = std::make_shared<PrewhereInfo>(); prewhere_info->prewhere_actions = std::make_shared<ActionsDAG>(input.getNamesAndTypesList()); std::string filter_name; // for in function if (rel.has_singular_or_list()) { const auto * in_node = parseExpression(prewhere_info->prewhere_actions, rel); prewhere_info->prewhere_actions->addOrReplaceInOutputs(*in_node); filter_name = in_node->result_name; } else { parseFunctionWithDAG(rel, filter_name, prewhere_info->prewhere_actions, true); } prewhere_info->prewhere_column_name = filter_name; prewhere_info->need_filter = true; prewhere_info->remove_prewhere_column = true; auto cols = prewhere_info->prewhere_actions->getRequiredColumnsNames(); // Keep it the same as the input. prewhere_info->prewhere_actions->removeUnusedActions(Names{filter_name}, false, true); prewhere_info->prewhere_actions->projectInput(false); for (const auto & name : input.getNames()) { prewhere_info->prewhere_actions->tryRestoreColumn(name); } return prewhere_info; } DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) { return wrapNullableType(nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE, nested_type); } DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) { if (nullable && !nested_type->isNullable()) return std::make_shared<DataTypeNullable>(nested_type); else return nested_type; } QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan) { logDebugMessage(*plan, "substrait plan"); parseExtensions(plan->extensions()); if (plan->relations_size() == 1) { auto root_rel = plan->relations().at(0); if (!root_rel.has_root()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!"); } std::list<const substrait::Rel *> rel_stack; auto query_plan = parseOp(root_rel.root().input(), rel_stack); if (root_rel.root().names_size()) { ActionsDAGPtr actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(query_plan->getCurrentDataStream().header)); NamesWithAliases aliases; auto cols = query_plan->getCurrentDataStream().header.getNamesAndTypesList(); if (cols.getNames().size() != static_cast<size_t>(root_rel.root().names_size())) { throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Missmatch result columns size."); } for (int i = 0; i < static_cast<int>(cols.getNames().size()); i++) { aliases.emplace_back(NameWithAlias(cols.getNames()[i], root_rel.root().names(i))); } actions_dag->project(aliases); auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag); expression_step->setStepDescription("Rename Output"); query_plan->addStep(std::move(expression_step)); } // fixes: issue-1874, to keep the nullability as expected. const auto & output_schema = root_rel.root().output_schema(); if (output_schema.types_size()) { auto original_header = query_plan->getCurrentDataStream().header; const auto & original_cols = original_header.getColumnsWithTypeAndName(); if (static_cast<size_t>(output_schema.types_size()) != original_cols.size()) { throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mismatch output schema"); } bool need_final_project = false; DB::ColumnsWithTypeAndName final_cols; for (int i = 0; i < output_schema.types_size(); ++i) { const auto & col = original_cols[i]; auto type = TypeParser::parseType(output_schema.types(i)); // At present, we only check nullable mismatch. // intermediate aggregate data is special, no check here. if (type->isNullable() != col.type->isNullable() && !typeid_cast<const DB::DataTypeAggregateFunction *>(col.type.get())) { if (type->isNullable()) { final_cols.emplace_back(type->createColumn(), std::make_shared<DB::DataTypeNullable>(col.type), col.name); } else { final_cols.emplace_back(type->createColumn(), DB::removeNullable(col.type), col.name); } need_final_project = true; } else { final_cols.push_back(col); } } if (need_final_project) { ActionsDAGPtr final_project = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position); QueryPlanStepPtr final_project_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), final_project); final_project_step->setStepDescription("Project for output schema"); query_plan->addStep(std::move(final_project_step)); } } return query_plan; } else { throw Exception(ErrorCodes::BAD_ARGUMENTS, "too many relations found"); } } QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack) { QueryPlanPtr query_plan; std::vector<IQueryPlanStep *> steps; switch (rel.rel_type_case()) { case substrait::Rel::RelTypeCase::kFetch: { rel_stack.push_back(&rel); const auto & limit = rel.fetch(); query_plan = parseOp(limit.input(), rel_stack); rel_stack.pop_back(); auto limit_step = std::make_unique<LimitStep>(query_plan->getCurrentDataStream(), limit.count(), limit.offset()); limit_step->setStepDescription("LIMIT"); steps.emplace_back(limit_step.get()); query_plan->addStep(std::move(limit_step)); break; } case substrait::Rel::RelTypeCase::kRead: { const auto & read = rel.read(); // TODO: We still maintain the old logic of parsing LocalFiles or ExtensionTable in RealRel // to be compatiable with some suites about metrics. // Remove this compatiability in later and then only java iter has local files in ReadRel. if (read.has_local_files() || (!read.has_extension_table() && !isReadFromMergeTree(read))) { assert(rel.has_base_schema()); QueryPlanStepPtr step; if (isReadRelFromJava(read)) step = parseReadRealWithJavaIter(read); else step = parseReadRealWithLocalFile(read); query_plan = std::make_unique<QueryPlan>(); steps.emplace_back(step.get()); query_plan->addStep(std::move(step)); // Add a buffer after source, it try to preload data from source and reduce the // waiting time of downstream nodes. if (context->getSettingsRef().max_threads > 1) { auto buffer_step = std::make_unique<BlocksBufferPoolStep>(query_plan->getCurrentDataStream()); steps.emplace_back(buffer_step.get()); query_plan->addStep(std::move(buffer_step)); } } else { substrait::ReadRel::ExtensionTable extension_table; if (read.has_extension_table()) extension_table = read.extension_table(); else extension_table = parseExtensionTable(split_infos.at(nextSplitInfoIndex())); MergeTreeRelParser mergeTreeParser(this, context, query_context, global_context); std::list<const substrait::Rel *> stack; query_plan = mergeTreeParser.parseReadRel(std::make_unique<QueryPlan>(), read, extension_table, stack); steps = mergeTreeParser.getSteps(); } break; } case substrait::Rel::RelTypeCase::kFilter: case substrait::Rel::RelTypeCase::kGenerate: case substrait::Rel::RelTypeCase::kProject: case substrait::Rel::RelTypeCase::kAggregate: case substrait::Rel::RelTypeCase::kSort: case substrait::Rel::RelTypeCase::kWindow: case substrait::Rel::RelTypeCase::kJoin: case substrait::Rel::RelTypeCase::kExpand: { auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this); query_plan = op_parser->parseOp(rel, rel_stack); auto parser_steps = op_parser->getSteps(); steps.insert(steps.end(), parser_steps.begin(), parser_steps.end()); break; } default: throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString()); } if (!context->getSettingsRef().query_plan_enable_optimizations) { if (rel.rel_type_case() == substrait::Rel::RelTypeCase::kRead) { size_t id = metrics.empty() ? 0 : metrics.back()->getId() + 1; metrics.emplace_back(std::make_shared<RelMetric>(id, String(magic_enum::enum_name(rel.rel_type_case())), steps)); } else metrics = {std::make_shared<RelMetric>(String(magic_enum::enum_name(rel.rel_type_case())), metrics, steps)}; } return query_plan; } NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & header) { NamesAndTypesList types; for (const auto & name : header.getNames()) { const auto * column = header.findByName(name); types.push_back(NameAndTypePair(column->name, column->type)); } return types; } std::string SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function) { auto args = function.arguments(); auto pos = function_signature.find(':'); auto function_name = function_signature.substr(0, pos); if (!SCALAR_FUNCTIONS.contains(function_name)) throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", function_name); std::string ch_function_name; if (function_name == "trim") ch_function_name = args.size() == 1 ? "trimBoth" : "trimBothSpark"; else if (function_name == "ltrim") ch_function_name = args.size() == 1 ? "trimLeft" : "trimLeftSpark"; else if (function_name == "rtrim") ch_function_name = args.size() == 1 ? "trimRight" : "trimRightSpark"; else if (function_name == "extract") { if (args.size() != 2) throw Exception( ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); // Get the first arg: field const auto & extract_field = args.at(0); if (extract_field.value().has_literal()) { const auto & field_value = extract_field.value().literal().string(); if (field_value == "YEAR") ch_function_name = "toYear"; // spark: extract(YEAR FROM) or year else if (field_value == "YEAR_OF_WEEK") ch_function_name = "toISOYear"; // spark: extract(YEAROFWEEK FROM) else if (field_value == "QUARTER") ch_function_name = "toQuarter"; // spark: extract(QUARTER FROM) or quarter else if (field_value == "MONTH") ch_function_name = "toMonth"; // spark: extract(MONTH FROM) or month else if (field_value == "WEEK_OF_YEAR") ch_function_name = "toISOWeek"; // spark: extract(WEEK FROM) or weekofyear else if (field_value == "WEEK_DAY") /// Spark WeekDay(date) (0 = Monday, 1 = Tuesday, ..., 6 = Sunday) /// Substrait: extract(WEEK_DAY from date) /// CH: toDayOfWeek(date, 1) ch_function_name = "toDayOfWeek"; else if (field_value == "DAY_OF_WEEK") /// Spark: DayOfWeek(date) (1 = Sunday, 2 = Monday, ..., 7 = Saturday) /// Substrait: extract(DAY_OF_WEEK from date) /// CH: toDayOfWeek(date, 3) /// DAYOFWEEK is alias of function toDayOfWeek. /// This trick is to distinguish between extract fields DAY_OF_WEEK and WEEK_DAY in latter codes ch_function_name = "DAYOFWEEK"; else if (field_value == "DAY") ch_function_name = "toDayOfMonth"; // spark: extract(DAY FROM) or dayofmonth else if (field_value == "DAY_OF_YEAR") ch_function_name = "toDayOfYear"; // spark: extract(DOY FROM) or dayofyear else if (field_value == "HOUR") ch_function_name = "toHour"; // spark: extract(HOUR FROM) or hour else if (field_value == "MINUTE") ch_function_name = "toMinute"; // spark: extract(MINUTE FROM) or minute else if (field_value == "SECOND") ch_function_name = "toSecond"; // spark: extract(SECOND FROM) or secondwithfraction else throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong."); } else throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong."); } else if (function_name == "sha2") { if (args.size() != 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Spark function sha2 requires two args, function:{}", function.ShortDebugString()); const auto & bit_length = args.at(1); if (!bit_length.value().has_literal() || !bit_length.value().literal().has_i32()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of spark sha2 function is wrong."); const auto & bit_length_value = bit_length.value().literal().i32(); if (bit_length_value == 224) ch_function_name = "SHA224"; else if (bit_length_value == 256 || bit_length_value == 0) ch_function_name = "SHA256"; else if (bit_length_value == 384) ch_function_name = "SHA384"; else if (bit_length_value == 512) ch_function_name = "SHA512"; else throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of spark sha2 function is wrong, value:{}", bit_length_value); } else if (function_name == "check_overflow") { if (args.size() < 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); ch_function_name = SCALAR_FUNCTIONS.at(function_name); auto null_on_overflow = args.at(1).value().literal().boolean(); if (null_on_overflow) ch_function_name = ch_function_name + "OrNull"; } else if (function_name == "make_decimal") { if (args.size() < 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args."); ch_function_name = SCALAR_FUNCTIONS.at(function_name); auto null_on_overflow = args.at(1).value().literal().boolean(); if (null_on_overflow) ch_function_name = ch_function_name + "OrNull"; } else if (function_name == "char_length") { /// In Spark /// char_length returns the number of bytes when input is binary type, corresponding to CH length function /// char_length returns the number of characters when input is string type, corresponding to CH char_length function ch_function_name = SCALAR_FUNCTIONS.at(function_name); if (function_signature.find("vbin") != std::string::npos) ch_function_name = "length"; } else if (function_name == "reverse") { if (function.output_type().has_list()) ch_function_name = "arrayReverse"; else ch_function_name = "reverseUTF8"; } else if (function_name == "concat") { /// 1. ConcatOverloadResolver cannot build arrayConcat for Nullable(Array) type which causes failures when using functions like concat(split()). /// So we use arrayConcat directly if the output type is array. /// 2. CH ConcatImpl can only accept at least 2 arguments, but Spark concat can accept 1 argument, like concat('a') /// in such case we use identity function if (function.output_type().has_list()) ch_function_name = "arrayConcat"; else if (args.size() == 1) ch_function_name = "identity"; else ch_function_name = "concat"; } else ch_function_name = SCALAR_FUNCTIONS.at(function_name); return ch_function_name; } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( const substrait::Expression & rel, std::vector<String> & result_names, DB::ActionsDAGPtr actions_dag, bool keep_result, bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString()); const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); auto function_name = getFunctionName(function_signature, scalar_function); if (function_name != "arrayJoin") throw Exception( ErrorCodes::LOGICAL_ERROR, "Function parseArrayJoinWithDAG should only process arrayJoin function, but input is {}", rel.ShortDebugString()); /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 if (scalar_function.arguments_size() != 1) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 but is {}", scalar_function.arguments_size()); ActionsDAG::NodeRawConstPtrs args; parseFunctionArguments(actions_dag, args, function_name, scalar_function); auto arg_type = DB::removeNullable(args[0]->result_type); /// array() or map() const auto * empty_map_or_array_node = addColumn(actions_dag, DB::removeNullable(args[0]->result_type), isMap(arg_type) ? Field(Map()) : Field(Array())); /// ifNull(args[0], array() or map()) const auto * if_null_node = toFunctionNode(actions_dag, "ifNull", {args[0], empty_map_or_array_node}); /// assumeNotNull(ifNull(args[0], array() or map())) const auto * arg_not_null = toFunctionNode(actions_dag, "assumeNotNull", {if_null_node}); /// Wrap with materalize function to make sure column input to ARRAY JOIN STEP is materaized arg_not_null = &actions_dag->materializeNode(*arg_not_null); /// arrayJoin(arg_not_null) /// Note: Make sure result_name keep the same after applying arrayJoin function, which makes it much easier to transform arrayJoin function to ARRAY JOIN STEP /// Otherwise an alias node must be appended after ARRAY JOIN STEP, which is not a graceful implementation. auto array_join_name = arg_not_null->result_name; const auto * array_join_node = &actions_dag->addArrayJoin(*arg_not_null, array_join_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared<DataTypeUInt32>(); auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, result_name); }; /// Special process to keep compatiable with Spark WhichDataType which(arg_type.get()); if (!position) { /// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map) if (which.isMap()) { /// In Spark: explode(map(k, v)) output 2 columns with default names "key" and "value" /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type. /// So we must wrap arrayJoin with sparkTupleElement function for compatiability. /// arrayJoin(arg_not_null).1 const auto * key_node = add_tuple_element(array_join_node, 1); /// arrayJoin(arg_not_null).2 const auto * val_node = add_tuple_element(array_join_node, 2); result_names.push_back(key_node->result_name); result_names.push_back(val_node->result_name); if (keep_result) { actions_dag->addOrReplaceInOutputs(*key_node); actions_dag->addOrReplaceInOutputs(*val_node); } return {key_node, val_node}; } else if (which.isArray()) { result_names.push_back(array_join_name); if (keep_result) actions_dag->addOrReplaceInOutputs(*array_join_node); return {array_join_node}; } else throw Exception( ErrorCodes::BAD_ARGUMENTS, "Argument type of arrayJoin converted from explode should be Array or Map but is {}", arg_type->getName()); } else { /// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map) if (which.isMap()) { /// In Spark: posexplode(array_of_map) output 2 or 3 columns: (pos, col) or (pos, key, value) /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type. /// So we must wrap arrayJoin with sparkTupleElement function for compatiability. /// pos = arrayJoin(arg_not_null).1 const auto * pos_node = add_tuple_element(array_join_node, 1); /// col = arrayJoin(arg_not_null).2 or (key, value) = arrayJoin(arg_not_null).2 const auto * item_node = add_tuple_element(array_join_node, 2); /// It is a tricky but efficient way to get the original type of argument type in posexplode if (endsWith(args[0]->result_name, "type_hint:map")) { /// key = arrayJoin(arg_not_null).2.1 const auto * item_key_node = add_tuple_element(item_node, 1); /// value = arrayJoin(arg_not_null).2.2 const auto * item_value_node = add_tuple_element(item_node, 2); result_names.push_back(pos_node->result_name); result_names.push_back(item_key_node->result_name); result_names.push_back(item_value_node->result_name); if (keep_result) { actions_dag->addOrReplaceInOutputs(*pos_node); actions_dag->addOrReplaceInOutputs(*item_key_node); actions_dag->addOrReplaceInOutputs(*item_value_node); } return {pos_node, item_key_node, item_value_node}; } else if (endsWith(args[0]->result_name, "type_hint:array")) { /// col = arrayJoin(arg_not_null).2 result_names.push_back(pos_node->result_name); result_names.push_back(item_node->result_name); if (keep_result) { actions_dag->addOrReplaceInOutputs(*pos_node); actions_dag->addOrReplaceInOutputs(*item_node); } return {pos_node, item_node}; } else throw Exception( ErrorCodes::BAD_ARGUMENTS, "The raw input of arrayJoin converted from posexplode should be Array or Map type"); } else throw Exception( ErrorCodes::BAD_ARGUMENTS, "Argument type of arrayJoin converted from posexplode should be Map but is {}", arg_type->getName()); } } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag, bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference())); /// If the substrait function name is registered in FunctionParserFactory, use it to parse the function, and return result directly auto pos = function_signature.find(':'); auto func_name = function_signature.substr(0, pos); auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this); if (func_parser) { LOG_DEBUG( &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); result_name = result_node->result_name; return result_node; } auto ch_func_name = getFunctionName(function_signature, scalar_function); ActionsDAG::NodeRawConstPtrs args; parseFunctionArguments(actions_dag, args, ch_func_name, scalar_function); /// If the first argument of function formatDateTimeInJodaSyntax is integer, replace formatDateTimeInJodaSyntax with fromUnixTimestampInJodaSyntax /// to avoid exception if (ch_func_name == "formatDateTimeInJodaSyntax") { if (args.size() > 1 && isInteger(DB::removeNullable(args[0]->result_type))) ch_func_name = "fromUnixTimestampInJodaSyntax"; } const ActionsDAG::Node * result_node; if (ch_func_name == "alias") { result_name = args[0]->result_name; actions_dag->addOrReplaceInOutputs(*args[0]); result_node = &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name); } else { if (ch_func_name == "splitByRegexp") { if (args.size() >= 2) { /// In Spark: split(str, regex [, limit] ) /// In CH: splitByRegexp(regexp, str [, limit]) std::swap(args[0], args[1]); } } if (function_signature.find("check_overflow:", 0) != function_signature.npos) { if (scalar_function.arguments().size() < 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args."); ActionsDAG::NodeRawConstPtrs new_args; new_args.reserve(3); new_args.emplace_back(args[0]); UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared<DataTypeUInt32>(); new_args.emplace_back(&actions_dag->addColumn( ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); new_args.emplace_back(&actions_dag->addColumn( ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } else if (startsWith(function_signature, "make_decimal:")) { if (scalar_function.arguments().size() < 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args."); ActionsDAG::NodeRawConstPtrs new_args; new_args.reserve(3); new_args.emplace_back(args[0]); UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared<DataTypeUInt32>(); new_args.emplace_back(&actions_dag->addColumn( ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); new_args.emplace_back(&actions_dag->addColumn( ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } bool converted_decimal_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, args, scalar_function); auto function_builder = FunctionFactory::instance().get(ch_func_name, context); std::string args_name = join(args, ','); result_name = ch_func_name + "(" + args_name + ")"; const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); result_node = function_node; if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args) { auto result_type = TypeParser::parseType(rel.scalar_function().output_type()); if (isDecimalOrNullableDecimal(result_type)) { result_node = ActionsDAGUtil::convertNodeType( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() : local_engine::removeNullable(result_type)->getName(), function_node->result_name, DB::CastType::accurateOrNull); } else { result_node = ActionsDAGUtil::convertNodeType( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() : local_engine::removeNullable(result_type)->getName(), function_node->result_name); } } if (ch_func_name == "JSON_VALUE") result_node->function->setResolver(function_builder); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); } return result_node; } bool SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs( ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun) { auto function_signature = function_mapping.at(std::to_string(arithmeticFun.function_reference())); auto pos = function_signature.find(':'); auto func_name = function_signature.substr(0, pos); if (func_name == "divide" || func_name == "multiply" || func_name == "plus" || func_name == "minus") { /// for divide/plus/minus, we need to convert first arg to result precision and scale /// for multiply, we need to convert first arg to result precision, but keep scale if (isDecimalOrNullableDecimal(args[0]->result_type) && isDecimalOrNullableDecimal(args[1]->result_type)) { UInt32 p1 = getDecimalPrecision(*DB::removeNullable(args[0]->result_type)); UInt32 s1 = getDecimalScale(*DB::removeNullable(args[0]->result_type)); UInt32 p2 = getDecimalPrecision(*DB::removeNullable(args[1]->result_type)); UInt32 s2 = getDecimalScale(*DB::removeNullable(args[1]->result_type)); UInt32 precision; UInt32 scale; if (func_name == "plus" || func_name == "minus") { scale = s1; precision = scale + std::max(p1 - s1, p2 - s2) + 1; } else if (func_name == "divide") { scale = std::max(static_cast<UInt32>(6), s1 + p2 + 1); precision = p1 - s1 + s2 + scale; } else // multiply { scale = s1; precision = p1 + p2 + 1; } UInt32 maxPrecision = DataTypeDecimal256::maxPrecision(); UInt32 maxScale = DataTypeDecimal128::maxPrecision(); precision = std::min(precision, maxPrecision); scale = std::min(scale, maxScale); ActionsDAG::NodeRawConstPtrs new_args; new_args.reserve(args.size()); ActionsDAG::NodeRawConstPtrs cast_args; cast_args.reserve(2); cast_args.emplace_back(args[0]); DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale); ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type); String type_name = ch_type->getName(); DataTypePtr str_type = std::make_shared<DataTypeString>(); const ActionsDAG::Node * type_node = &actions_dag->addColumn( ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name))); cast_args.emplace_back(type_node); const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args); actions_dag->addOrReplaceInOutputs(*cast_node); new_args.emplace_back(cast_node); new_args.emplace_back(args[1]); args = std::move(new_args); return true; } } return false; } void SerializedPlanParser::parseFunctionArguments( DB::ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, std::string & function_name, const substrait::Expression_ScalarFunction & scalar_function) { auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference())); const auto & args = scalar_function.arguments(); parsed_args.reserve(args.size()); // Some functions need to be handled specially. if (function_name == "JSONExtract") { parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); auto data_type = TypeParser::parseType(scalar_function.output_type()); parsed_args.emplace_back(addColumn(actions_dag, std::make_shared<DB::DataTypeString>(), data_type->getName())); } else if (function_name == "sparkTupleElement" || function_name == "tupleElement") { parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); if (!args[1].value().has_literal()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal"); auto [data_type, field] = parseLiteral(args[1].value().literal()); if (data_type->getTypeId() != DB::TypeIndex::Int32) throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32"); // tuple indecies start from 1, in spark, start from 0 Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1); const auto * index_node = addColumn(actions_dag, std::make_shared<DB::DataTypeUInt32>(), field_index); parsed_args.emplace_back(index_node); } else if (function_name == "tuple") { // Arguments in the format, (<field name>, <value expression>[, <field name>, <value expression> ...]) // We don't need to care the field names here. for (int index = 1; index < args.size(); index += 2) parseFunctionArgument(actions_dag, parsed_args, function_name, args[index]); } else if (function_name == "repeat") { // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait // which must be a positive value into unsigned integer here. parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, function_name, args[1]); DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>()); repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName()); parsed_args.emplace_back(repeat_times_node); } else if (function_name == "isNaN") { // the result of isNaN(NULL) is NULL in CH, but false in Spark const DB::ActionsDAG::Node * arg_node = nullptr; if (args[0].value().has_cast()) { arg_node = parseExpression(actions_dag, args[0].value().cast().input()); const auto * res_type = arg_node->result_type.get(); if (res_type->isNullable()) { res_type = typeid_cast<const DB::DataTypeNullable *>(res_type)->getNestedType().get(); } if (isString(*res_type)) { DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node}; arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args); } else { arg_node = parseFunctionArgument(actions_dag, function_name, args[0]); } } else { arg_node = parseFunctionArgument(actions_dag, function_name, args[0]); } DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)}; parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args)); } else if (function_name == "positionUTF8Spark") { if (args.size() >= 2) { // In Spark: position(substr, str, Int32) // In CH: position(str, subtr, UInt32) parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]); parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); } if (args.size() >= 3) { // add cast: cast(start_pos as UInt32) const auto * start_pos_node = parseFunctionArgument(actions_dag, function_name, args[2]); DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>()); start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName()); parsed_args.emplace_back(start_pos_node); } } else if (function_name == "space") { // convert space function to repeat const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, "repeat", args[0]); const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), " "); function_name = "repeat"; parsed_args.emplace_back(space_str_node); parsed_args.emplace_back(repeat_times_node); } else if (function_name == "trimBothSpark" || function_name == "trimLeftSpark" || function_name == "trimRightSpark") { /// In substrait, the first arg is srcStr, the second arg is trimStr /// But in CH, the first arg is trimStr, the second arg is srcStr parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]); parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]); } else if (startsWith(function_signature, "extract:")) { /// Skip the first arg of extract in substrait for (int i = 1; i < args.size(); i++) parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]); /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait if (function_name == "toDayOfWeek" || function_name == "DAYOFWEEK") { UInt8 mode = function_name == "toDayOfWeek" ? 1 : 3; auto mode_type = std::make_shared<DataTypeUInt8>(); ColumnWithTypeAndName mode_col(mode_type->createColumnConst(1, mode), mode_type, getUniqueName(std::to_string(mode))); const auto & mode_node = actions_dag->addColumn(std::move(mode_col)); parsed_args.emplace_back(&mode_node); } } else if (startsWith(function_signature, "sha2:")) { for (int i = 0; i < args.size() - 1; i++) parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]); } else { // Default handle for (const auto & arg : args) parseFunctionArgument(actions_dag, parsed_args, function_name, arg); } } void SerializedPlanParser::parseFunctionArgument( DB::ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, const std::string & function_name, const substrait::FunctionArgument & arg) { parsed_args.emplace_back(parseFunctionArgument(actions_dag, function_name, arg)); } const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( DB::ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) { const DB::ActionsDAG::Node * res; if (arg.value().has_scalar_function()) { std::string arg_name; bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name); parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg); res = &actions_dag->getNodes().back(); } else { res = parseExpression(actions_dag, arg.value()); } return res; } // Convert signed integer index into unsigned integer index std::pair<DB::DataTypePtr, DB::Field> SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field) { // For tupelElement, field index starts from 1, but int substrait plan, it starts from 0. #define UINT_CONVERT(type_ptr, field, type_name) \ if ((type_ptr)->getTypeId() == DB::TypeIndex::type_name) \ { \ return {std::make_shared<DB::DataTypeU##type_name>(), static_cast<U##type_name>((field).get<type_name>()) + 1}; \ } auto type_id = type->getTypeId(); if (type_id == DB::TypeIndex::UInt8 || type_id == DB::TypeIndex::UInt16 || type_id == DB::TypeIndex::UInt32 || type_id == DB::TypeIndex::UInt64) { return {type, field}; } UINT_CONVERT(type, field, Int8) UINT_CONVERT(type, field, Int16) UINT_CONVERT(type, field, Int32) UINT_CONVERT(type, field, Int64) throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid interger type: {}", type->getName()); #undef UINT_CONVERT } ActionsDAGPtr SerializedPlanParser::parseFunction( const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); parseFunctionWithDAG(rel, result_name, actions_dag, keep_result); return actions_dag; } ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression( const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); if (rel.has_scalar_function()) parseFunctionWithDAG(rel, result_name, actions_dag, keep_result); else { const auto * result_node = parseExpression(actions_dag, rel); result_name = result_node->result_name; } return actions_dag; } ActionsDAGPtr SerializedPlanParser::parseArrayJoin( const Block & input, const substrait::Expression & rel, std::vector<String> & result_names, ActionsDAGPtr actions_dag, bool keep_result, bool position) { if (!actions_dag) actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(input)); parseArrayJoinWithDAG(rel, result_names, actions_dag, keep_result, position); return actions_dag; } ActionsDAGPtr SerializedPlanParser::parseJsonTuple( const Block & input, const substrait::Expression & rel, std::vector<String> & result_names, ActionsDAGPtr actions_dag, bool keep_result, bool) { if (!actions_dag) { actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(input)); } const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); auto function_name = getFunctionName(function_signature, scalar_function); auto args = scalar_function.arguments(); if (args.size() < 2) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "The function json_tuple should has at least 2 arguments."); } auto first_arg = args[0].value(); const DB::ActionsDAG::Node * json_expr_node = parseExpression(actions_dag, first_arg); std::string extract_expr = "Tuple("; for (int i = 1; i < args.size(); i++) { auto arg_value = args[i].value(); if (!arg_value.has_literal()) { throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The arguments of function {} must be string literal", function_name); } DB::Field f = arg_value.literal().string(); std::string s; if (f.tryGet(s)) { extract_expr.append(s).append(" Nullable(String)"); if (i != args.size() - 1) { extract_expr.append(","); } } } extract_expr.append(")"); const DB::ActionsDAG::Node * extract_expr_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), extract_expr); auto json_extract_builder = FunctionFactory::instance().get("JSONExtract", context); auto json_extract_result_name = "JSONExtract(" + json_expr_node->result_name + "," + extract_expr_node->result_name + ")"; const ActionsDAG::Node * json_extract_node = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared<DataTypeUInt32>(); auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, result_name); }; for (int i = 1; i < args.size(); i++) { const ActionsDAG::Node * tuple_node = add_tuple_element(json_extract_node, i); if (keep_result) { actions_dag->addOrReplaceInOutputs(*tuple_node); result_names.push_back(tuple_node->result_name); } } return actions_dag; } const ActionsDAG::Node * SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args) { auto function_builder = DB::FunctionFactory::instance().get(function, context); std::string args_name = join(args, ','); auto result_name = function + "(" + args_name + ")"; const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); return function_node; } std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait::Expression_Literal & literal) { DataTypePtr type; Field field; switch (literal.literal_type_case()) { case substrait::Expression_Literal::kFp64: { type = std::make_shared<DataTypeFloat64>(); field = literal.fp64(); break; } case substrait::Expression_Literal::kFp32: { type = std::make_shared<DataTypeFloat32>(); field = literal.fp32(); break; } case substrait::Expression_Literal::kString: { type = std::make_shared<DataTypeString>(); field = literal.string(); break; } case substrait::Expression_Literal::kBinary: { type = std::make_shared<DataTypeString>(); field = literal.binary(); break; } case substrait::Expression_Literal::kI64: { type = std::make_shared<DataTypeInt64>(); field = literal.i64(); break; } case substrait::Expression_Literal::kI32: { type = std::make_shared<DataTypeInt32>(); field = literal.i32(); break; } case substrait::Expression_Literal::kBoolean: { type = std::make_shared<DataTypeUInt8>(); field = literal.boolean() ? UInt8(1) : UInt8(0); break; } case substrait::Expression_Literal::kI16: { type = std::make_shared<DataTypeInt16>(); field = literal.i16(); break; } case substrait::Expression_Literal::kI8: { type = std::make_shared<DataTypeInt8>(); field = literal.i8(); break; } case substrait::Expression_Literal::kDate: { type = std::make_shared<DataTypeDate32>(); field = literal.date(); break; } case substrait::Expression_Literal::kTimestamp: { type = std::make_shared<DataTypeDateTime64>(6); field = DecimalField<DateTime64>(literal.timestamp(), 6); break; } case substrait::Expression_Literal::kDecimal: { UInt32 precision = literal.decimal().precision(); UInt32 scale = literal.decimal().scale(); const auto & bytes = literal.decimal().value(); if (precision <= DataTypeDecimal32::maxPrecision()) { type = std::make_shared<DataTypeDecimal32>(precision, scale); auto value = *reinterpret_cast<const Int32 *>(bytes.data()); field = DecimalField<Decimal32>(value, scale); } else if (precision <= DataTypeDecimal64::maxPrecision()) { type = std::make_shared<DataTypeDecimal64>(precision, scale); auto value = *reinterpret_cast<const Int64 *>(bytes.data()); field = DecimalField<Decimal64>(value, scale); } else if (precision <= DataTypeDecimal128::maxPrecision()) { type = std::make_shared<DataTypeDecimal128>(precision, scale); String bytes_copy(bytes); auto value = *reinterpret_cast<Decimal128 *>(bytes_copy.data()); field = DecimalField<Decimal128>(value, scale); } else throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision); break; } case substrait::Expression_Literal::kList: { const auto & values = literal.list().values(); if (values.empty()) { type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeNothing>()); field = Array(); break; } DataTypePtr common_type; std::tie(common_type, std::ignore) = parseLiteral(values[0]); size_t list_len = values.size(); Array array(list_len); for (int i = 0; i < static_cast<int>(list_len); ++i) { auto type_and_field = parseLiteral(values[i]); common_type = getLeastSupertype(DataTypes{common_type, type_and_field.first}); array[i] = std::move(type_and_field.second); } type = std::make_shared<DataTypeArray>(common_type); field = std::move(array); break; } case substrait::Expression_Literal::kMap: { const auto & key_values = literal.map().key_values(); if (key_values.empty()) { type = std::make_shared<DataTypeMap>(std::make_shared<DataTypeNothing>(), std::make_shared<DataTypeNothing>()); field = Map(); break; } const auto & first_key_value = key_values[0]; DataTypePtr common_key_type; std::tie(common_key_type, std::ignore) = parseLiteral(first_key_value.key()); DataTypePtr common_value_type; std::tie(common_value_type, std::ignore) = parseLiteral(first_key_value.value()); Map map; map.reserve(key_values.size()); for (const auto & key_value : key_values) { Tuple tuple(2); DataTypePtr key_type; std::tie(key_type, tuple[0]) = parseLiteral(key_value.key()); /// Each key should has the same type if (!common_key_type->equals(*key_type)) throw Exception( ErrorCodes::LOGICAL_ERROR, "Literal map key type mismatch:{} and {}", common_key_type->getName(), key_type->getName()); DataTypePtr value_type; std::tie(value_type, tuple[1]) = parseLiteral(key_value.value()); /// Each value should has least super type for all of them common_value_type = getLeastSupertype(DataTypes{common_value_type, value_type}); map.emplace_back(std::move(tuple)); } type = std::make_shared<DataTypeMap>(common_key_type, common_value_type); field = std::move(map); break; } case substrait::Expression_Literal::kStruct: { const auto & fields = literal.struct_().fields(); DataTypes types; types.reserve(fields.size()); Tuple tuple; tuple.reserve(fields.size()); for (const auto & f : fields) { DataTypePtr field_type; Field field_value; std::tie(field_type, field_value) = parseLiteral(f); types.emplace_back(std::move(field_type)); tuple.emplace_back(std::move(field_value)); } type = std::make_shared<DataTypeTuple>(types); field = std::move(tuple); break; } case substrait::Expression_Literal::kNull: { type = TypeParser::parseType(literal.null()); field = Field{}; break; } default: { throw Exception( ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); } } return std::make_pair(std::move(type), std::move(field)); } const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr actions_dag, const substrait::Expression & rel) { switch (rel.rex_type_case()) { case substrait::Expression::RexTypeCase::kLiteral: { DataTypePtr type; Field field; std::tie(type, field) = parseLiteral(rel.literal()); return addColumn(actions_dag, type, field); } case substrait::Expression::RexTypeCase::kSelection: { if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections"); const auto * field = actions_dag->getInputs()[rel.selection().direct_reference().struct_field().field()]; return actions_dag->tryFindInOutputs(field->result_name); } case substrait::Expression::RexTypeCase::kCast: { if (!rel.cast().has_type() || !rel.cast().has_input()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); DB::ActionsDAG::NodeRawConstPtrs args; const auto & input = rel.cast().input(); args.emplace_back(parseExpression(actions_dag, input)); const auto & substrait_type = rel.cast().type(); const ActionsDAG::Node * function_node = nullptr; if (DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_date()) { function_node = toFunctionNode(actions_dag, "spark_to_date", args); } else if (substrait_type.has_binary()) { // Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x) function_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args); } else { DataTypePtr ch_type = TypeParser::parseType(substrait_type); if(DB::isString(DB::removeNullable(ch_type)) && isDecimalOrNullableDecimal(args[0]->result_type)) { UInt8 scale = getDecimalScale(*DB::removeNullable(args[0]->result_type)); args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale))); function_node = toFunctionNode(actions_dag, "toDecimalString", args); } else { if (isFloat(DB::removeNullable(args[0]->result_type)) && isInt(DB::removeNullable(ch_type))) { String function_name = "sparkCastFloatTo" + DB::removeNullable(ch_type)->getName(); function_node = toFunctionNode(actions_dag, function_name, args); } else { args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeString>(), ch_type->getName())); function_node = toFunctionNode(actions_dag, "CAST", args); } } } actions_dag->addOrReplaceInOutputs(*function_node); return function_node; } case substrait::Expression::RexTypeCase::kIfThen: { const auto & if_then = rel.if_then(); DB::FunctionOverloadResolverPtr function_ptr = nullptr; auto condition_nums = if_then.ifs_size(); if (condition_nums == 1) function_ptr = DB::FunctionFactory::instance().get("if", context); else function_ptr = DB::FunctionFactory::instance().get("multiIf", context); DB::ActionsDAG::NodeRawConstPtrs args; for (int i = 0; i < condition_nums; ++i) { const auto & ifs = if_then.ifs(i); const auto * if_node = parseExpression(actions_dag, ifs.if_()); args.emplace_back(if_node); const auto * then_node = parseExpression(actions_dag, ifs.then()); args.emplace_back(then_node); } const auto * else_node = parseExpression(actions_dag, if_then.else_()); args.emplace_back(else_node); std::string args_name = join(args, ','); std::string result_name; if (condition_nums == 1) result_name = "if(" + args_name + ")"; else result_name = "multiIf(" + args_name + ")"; const auto * function_node = &actions_dag->addFunction(function_ptr, args, result_name); actions_dag->addOrReplaceInOutputs(*function_node); return function_node; } case substrait::Expression::RexTypeCase::kScalarFunction: { std::string result; return parseFunctionWithDAG(rel, result, actions_dag, false); } case substrait::Expression::RexTypeCase::kSingularOrList: { const auto & options = rel.singular_or_list().options(); /// options is empty always return false if (options.empty()) return addColumn(actions_dag, std::make_shared<DataTypeUInt8>(), 0); /// options should be literals if (!options[0].has_literal()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type"); DB::ActionsDAG::NodeRawConstPtrs args; args.emplace_back(parseExpression(actions_dag, rel.singular_or_list().value())); bool nullable = false; int options_len = static_cast<int>(options.size()); for (int i = 0; i < options_len; ++i) { if (!options[i].has_literal()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); if (!nullable) nullable = options[i].literal().has_null(); } DataTypePtr elem_type; std::tie(elem_type, std::ignore) = parseLiteral(options[0].literal()); elem_type = wrapNullableType(nullable, elem_type); MutableColumnPtr elem_column = elem_type->createColumn(); elem_column->reserve(options_len); for (int i = 0; i < options_len; ++i) { auto type_and_field = parseLiteral(options[i].literal()); auto option_type = wrapNullableType(nullable, type_and_field.first); if (!elem_type->equals(*option_type)) throw Exception( ErrorCodes::LOGICAL_ERROR, "SingularOrList options type mismatch:{} and {}", elem_type->getName(), option_type->getName()); elem_column->insert(type_and_field.second); } MutableColumns elem_columns; elem_columns.emplace_back(std::move(elem_column)); auto name = getUniqueName("__set"); Block elem_block; elem_block.insert(ColumnWithTypeAndName(nullptr, elem_type, name)); elem_block.setColumns(std::move(elem_columns)); SizeLimits limit; auto elem_set = std::make_shared<Set>(limit, true, false); elem_set->setHeader(elem_block.getColumnsWithTypeAndName()); elem_set->insertFromBlock(elem_block.getColumnsWithTypeAndName()); elem_set->finishInsert(); auto future_set = std::make_shared<FutureSetFromStorage>(std::move(elem_set)); auto arg = ColumnSet::create(1, std::move(future_set)); args.emplace_back(&actions_dag->addColumn(ColumnWithTypeAndName(std::move(arg), std::make_shared<DataTypeSet>(), name))); const auto * function_node = toFunctionNode(actions_dag, "in", args); actions_dag->addOrReplaceInOutputs(*function_node); if (nullable) { /// if sets has `null` and value not in sets /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920) /// In CH: return `false` /// So we used if(a, b, c) cast `false` to `null` if sets has `null` auto type = wrapNullableType(true, function_node->result_type); DB::ActionsDAG::NodeRawConstPtrs cast_args( {function_node, addColumn(actions_dag, type, true), addColumn(actions_dag, type, Field())}); auto cast = FunctionFactory::instance().get("if", context); function_node = toFunctionNode(actions_dag, "if", cast_args); actions_dag->addOrReplaceInOutputs(*function_node); } return function_node; } default: throw Exception( ErrorCodes::UNKNOWN_TYPE, "Unsupported spark expression type {} : {}", magic_enum::enum_name(rel.rex_type_case()), rel.DebugString()); } } substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(const std::string & split_info) { substrait::ReadRel::ExtensionTable extension_table; google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = extension_table.ParseFromCodedStream(&coded_in); if (!ok) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::ReadRel::ExtensionTable from string failed"); logDebugMessage(extension_table, "extension_table"); return extension_table; } substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std::string & split_info) { substrait::ReadRel::LocalFiles local_files; google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = local_files.ParseFromCodedStream(&coded_in); if (!ok) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::ReadRel::LocalFiles from string failed"); logDebugMessage(local_files, "local_files"); return local_files; } QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) { auto plan_ptr = std::make_unique<substrait::Plan>(); /// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. /// Once this problem occurs, it is difficult to troubleshoot, because the pb of c++ will not provide any valid information google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(plan.data()), static_cast<int>(plan.size())); coded_in.SetRecursionLimit(100000); auto ok = plan_ptr->ParseFromCodedStream(&coded_in); if (!ok) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); auto res = parse(std::move(plan_ptr)); auto * logger = &Poco::Logger::get("SerializedPlanParser"); if (logger->debug()) { auto out = PlanUtil::explainPlan(*res); LOG_DEBUG(logger, "clickhouse plan:\n{}", out); } return res; } QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) { auto plan_ptr = std::make_unique<substrait::Plan>(); auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan.c_str()), plan_ptr.get()); if (!s.ok()) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from json string failed: {}", s.ToString()); return parse(std::move(plan_ptr)); } SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) { } ContextMutablePtr SerializedPlanParser::global_context = nullptr; Context::ConfigurationPtr SerializedPlanParser::config = nullptr; void SerializedPlanParser::collectJoinKeys( const substrait::Expression & condition, std::vector<std::pair<int32_t, int32_t>> & join_keys, int32_t right_key_start) { auto condition_name = getFunctionName( function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); if (condition_name == "and") { collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); collectJoinKeys(condition.scalar_function().arguments(1).value(), join_keys, right_key_start); } else if (condition_name == "equals") { const auto & function = condition.scalar_function(); auto left_key_idx = function.arguments(0).value().selection().direct_reference().struct_field().field(); auto right_key_idx = function.arguments(1).value().selection().direct_reference().struct_field().field() - right_key_start; join_keys.emplace_back(std::pair(left_key_idx, right_key_idx)); } else { throw Exception(ErrorCodes::BAD_ARGUMENTS, "doesn't support condition {}", condition_name); } } ActionsDAGPtr ASTParser::convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast) { NamesAndTypesList aggregation_keys; ColumnNumbersList aggregation_keys_indexes_list; AggregationKeysInfo info(aggregation_keys, aggregation_keys_indexes_list, GroupByKind::NONE); SizeLimits size_limits_for_set; ActionsMatcher::Data visitor_data( context, size_limits_for_set, size_t(0), name_and_types, std::make_shared<ActionsDAG>(name_and_types), std::make_shared<PreparedSets>(), false /* no_subqueries */, false /* no_makeset */, false /* only_consts */, info); ActionsVisitor(visitor_data).visit(ast); return visitor_data.getActions(); } ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & rel) { LOG_DEBUG(&Poco::Logger::get("ASTParser"), "substrait plan:\n{}", rel.DebugString()); if (rel.has_singular_or_list()) return parseArgumentToAST(names, rel); if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); const auto & scalar_function = rel.scalar_function(); auto function_signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference())); auto function_name = SerializedPlanParser::getFunctionName(function_signature, scalar_function); ASTs ast_args; parseFunctionArgumentsToAST(names, scalar_function, ast_args); return makeASTFunction(function_name, ast_args); } void ASTParser::parseFunctionArgumentsToAST( const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) { const auto & args = scalar_function.arguments(); for (const auto & arg : args) { if (arg.value().has_scalar_function()) { ast_args.emplace_back(parseToAST(names, arg.value())); } else { ast_args.emplace_back(parseArgumentToAST(names, arg.value())); } } } ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expression & rel) { switch (rel.rex_type_case()) { case substrait::Expression::RexTypeCase::kLiteral: { DataTypePtr type; Field field; std::tie(std::ignore, field) = SerializedPlanParser::parseLiteral(rel.literal()); return std::make_shared<ASTLiteral>(field); } case substrait::Expression::RexTypeCase::kSelection: { if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections"); const auto field = rel.selection().direct_reference().struct_field().field(); return std::make_shared<ASTIdentifier>(names[field]); } case substrait::Expression::RexTypeCase::kCast: { if (!rel.cast().has_type() || !rel.cast().has_input()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); /// Append input to asts ASTs args; args.emplace_back(parseArgumentToAST(names, rel.cast().input())); /// Append destination type to asts const auto & substrait_type = rel.cast().type(); /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x) if (substrait_type.has_binary()) return makeASTFunction("reinterpretAsStringSpark", args); else { DataTypePtr ch_type = TypeParser::parseType(substrait_type); args.emplace_back(std::make_shared<ASTLiteral>(ch_type->getName())); return makeASTFunction("CAST", args); } } case substrait::Expression::RexTypeCase::kIfThen: { const auto & if_then = rel.if_then(); auto condition_nums = if_then.ifs_size(); std::string ch_function_name = condition_nums == 1 ? "if" : "multiIf"; auto function_multi_if = DB::FunctionFactory::instance().get(ch_function_name, context); ASTs args; for (int i = 0; i < condition_nums; ++i) { const auto & ifs = if_then.ifs(i); auto if_node = parseArgumentToAST(names, ifs.if_()); args.emplace_back(if_node); auto then_node = parseArgumentToAST(names, ifs.then()); args.emplace_back(then_node); } auto else_node = parseArgumentToAST(names, if_then.else_()); args.emplace_back(std::move(else_node)); return makeASTFunction(ch_function_name, args); } case substrait::Expression::RexTypeCase::kScalarFunction: { return parseToAST(names, rel); } case substrait::Expression::RexTypeCase::kSingularOrList: { const auto & options = rel.singular_or_list().options(); /// options is empty always return false if (options.empty()) return std::make_shared<ASTLiteral>(0); /// options should be literals if (!options[0].has_literal()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type"); ASTs args; args.emplace_back(parseArgumentToAST(names, rel.singular_or_list().value())); bool nullable = false; size_t options_len = options.size(); ASTs in_args; in_args.reserve(options_len); for (int i = 0; i < static_cast<int>(options_len); ++i) { if (!options[i].has_literal()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!"); if (!nullable) nullable = options[i].literal().has_null(); } auto elem_type_and_field = SerializedPlanParser::parseLiteral(options[0].literal()); DataTypePtr elem_type = wrapNullableType(nullable, elem_type_and_field.first); for (int i = 0; i < static_cast<int>(options_len); ++i) { auto type_and_field = SerializedPlanParser::parseLiteral(options[i].literal()); auto option_type = wrapNullableType(nullable, type_and_field.first); if (!elem_type->equals(*option_type)) throw Exception( ErrorCodes::LOGICAL_ERROR, "SingularOrList options type mismatch:{} and {}", elem_type->getName(), option_type->getName()); in_args.emplace_back(std::make_shared<ASTLiteral>(type_and_field.second)); } auto array_ast = makeASTFunction("array", in_args); args.emplace_back(array_ast); auto ast = makeASTFunction("in", args); if (nullable) { /// if sets has `null` and value not in sets /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920) /// In CH: return `false` /// So we used if(a, b, c) cast `false` to `null` if sets has `null` ast = makeASTFunction("if", ast, std::make_shared<ASTLiteral>(true), std::make_shared<ASTLiteral>(Field())); } return ast; } default: throw Exception( ErrorCodes::UNKNOWN_TYPE, "Join on condition error. Unsupported spark expression type {} : {}", magic_enum::enum_name(rel.rex_type_case()), rel.DebugString()); } } void SerializedPlanParser::removeNullable(const std::set<String> & require_columns, ActionsDAGPtr actions_dag) { for (const auto & item : require_columns) { const auto * require_node = actions_dag->tryFindInOutputs(item); if (require_node) { auto function_builder = FunctionFactory::instance().get("assumeNotNull", context); ActionsDAG::NodeRawConstPtrs args = {require_node}; const auto & node = actions_dag->addFunction(function_builder, args, item); actions_dag->addOrReplaceInOutputs(node); } } } void SerializedPlanParser::wrapNullable( const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names) { for (const auto & item : columns) { ActionsDAG::NodeRawConstPtrs args; args.emplace_back(&actions_dag->findInOutputs(item)); const auto * node = toFunctionNode(actions_dag, "toNullable", args); actions_dag->addOrReplaceInOutputs(*node); nullable_measure_names[item] = node->result_name; } } SharedContextHolder SerializedPlanParser::shared_context; LocalExecutor::~LocalExecutor() { if (context->getConfigRef().getBool("dump_pipeline", false)) LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Dump pipeline:\n{}", dumpPipeline()); if (spark_buffer) { ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); spark_buffer.reset(); } } void LocalExecutor::execute(QueryPlanPtr query_plan) { Stopwatch stopwatch; const Settings & settings = context->getSettingsRef(); current_query_plan = std::move(query_plan); auto * logger = &Poco::Logger::get("LocalExecutor"); DB::QueryPriorities priorities; auto query_status = std::make_shared<DB::QueryStatus>( context, "", context->getClientInfo(), priorities.insert(static_cast<int>(settings.priority)), DB::CurrentThread::getGroup(), DB::IAST::QueryKind::Select, settings, 0); QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; auto pipeline_builder = current_query_plan->buildQueryPipeline( optimization_settings, BuildQueryPipelineSettings{ .actions_settings = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}, .process_list_element = query_status}); LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}", PlanUtil::explainPlan(*current_query_plan)); query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(query_pipeline)); auto t_pipeline = stopwatch.elapsedMicroseconds(); executor = std::make_unique<PullingPipelineExecutor>(query_pipeline); auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline; stopwatch.stop(); LOG_INFO( logger, "build pipeline {} ms; create executor {} ms;", t_pipeline / 1000.0, t_executor / 1000.0); header = current_query_plan->getCurrentDataStream().header.cloneEmpty(); ch_column_to_spark_row = std::make_unique<CHColumnToSparkRow>(); } std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(Block & block) { return ch_column_to_spark_row->convertCHColumnToSparkRow(block); } bool LocalExecutor::hasNext() { bool has_next; try { size_t columns = currentBlock().columns(); if (columns == 0 || isConsumed()) { auto empty_block = header.cloneEmpty(); setCurrentBlock(empty_block); has_next = executor->pull(currentBlock()); produce(); } else { has_next = true; } } catch (DB::Exception & e) { LOG_ERROR( &Poco::Logger::get("LocalExecutor"), "LocalExecutor run query plan failed with message: {}. Plan Explained: \n{}", e.message(), PlanUtil::explainPlan(*current_query_plan)); throw; } return has_next; } SparkRowInfoPtr LocalExecutor::next() { checkNextValid(); SparkRowInfoPtr row_info = writeBlockToSparkRow(currentBlock()); consume(); if (spark_buffer) { ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); spark_buffer.reset(); } spark_buffer = std::make_unique<SparkBuffer>(); spark_buffer->address = row_info->getBufferAddress(); spark_buffer->size = row_info->getTotalBytes(); return row_info; } Block * LocalExecutor::nextColumnar() { checkNextValid(); Block * columnar_batch; if (currentBlock().columns() > 0) { columnar_batch = &currentBlock(); } else { auto empty_block = header.cloneEmpty(); setCurrentBlock(empty_block); columnar_batch = &currentBlock(); } consume(); return columnar_batch; } Block & LocalExecutor::getHeader() { return header; } LocalExecutor::LocalExecutor(QueryContext & _query_context, ContextPtr context_) : query_context(_query_context), context(context_) { } std::string LocalExecutor::dumpPipeline() { const auto & processors = query_pipeline.getProcessors(); for (auto & processor : processors) { DB::WriteBufferFromOwnString buffer; auto data_stats = processor->getProcessorDataStats(); buffer << "("; buffer << "\nexcution time: " << processor->getElapsedUs() << " us."; buffer << "\ninput wait time: " << processor->getInputWaitElapsedUs() << " us."; buffer << "\noutput wait time: " << processor->getOutputWaitElapsedUs() << " us."; buffer << "\ninput rows: " << data_stats.input_rows; buffer << "\ninput bytes: " << data_stats.input_bytes; buffer << "\noutput rows: " << data_stats.output_rows; buffer << "\noutput bytes: " << data_stats.output_bytes; buffer << ")"; processor->setDescription(buffer.str()); } DB::WriteBufferFromOwnString out; DB::printPipeline(processors, out); return out.str(); } NonNullableColumnsResolver::NonNullableColumnsResolver( const DB::Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_) : header(header_), parser(parser_), cond_rel(cond_rel_) { } // make it simple at present, if the condition contains or, return empty for both side. std::set<String> NonNullableColumnsResolver::resolve() { collected_columns.clear(); visit(cond_rel); return collected_columns; } // TODO: make it the same as spark, it's too simple at present. void NonNullableColumnsResolver::visit(const substrait::Expression & expr) { if (!expr.has_scalar_function()) return; const auto & scalar_function = expr.scalar_function(); auto function_signature = parser.function_mapping.at(std::to_string(scalar_function.function_reference())); auto function_name = safeGetFunctionName(function_signature, scalar_function); // Only some special functions are used to judge whether the column is non-nullable. if (function_name == "and") { visit(scalar_function.arguments(0).value()); visit(scalar_function.arguments(1).value()); } else if (function_name == "greaterOrEquals" || function_name == "greater") { // If it's the case, a > x, what ever x or a is, a and x are non-nullable. // a or x may be a column, or a simple expression like plus etc. visitNonNullable(scalar_function.arguments(0).value()); visitNonNullable(scalar_function.arguments(1).value()); } else if (function_name == "lessOrEquals" || function_name == "less") { // same as gt, gte. visitNonNullable(scalar_function.arguments(0).value()); visitNonNullable(scalar_function.arguments(1).value()); } else if (function_name == "isNotNull") { visitNonNullable(scalar_function.arguments(0).value()); } // else do nothing } void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & expr) { if (expr.has_scalar_function()) { const auto & scalar_function = expr.scalar_function(); auto function_signature = parser.function_mapping.at(std::to_string(scalar_function.function_reference())); auto function_name = safeGetFunctionName(function_signature, scalar_function); if (function_name == "plus" || function_name == "minus" || function_name == "multiply" || function_name == "divide") { visitNonNullable(scalar_function.arguments(0).value()); visitNonNullable(scalar_function.arguments(1).value()); } } else if (expr.has_selection()) { const auto & selection = expr.selection(); auto column_pos = selection.direct_reference().struct_field().field(); auto column_name = header.getByPosition(column_pos).name; collected_columns.insert(column_name); } // else, do nothing. } std::string NonNullableColumnsResolver::safeGetFunctionName( const std::string & function_signature, const substrait::Expression_ScalarFunction & function) { try { return parser.getFunctionName(function_signature, function); } catch (const Exception &) { return ""; } } }