/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace local_engine { static const std::map SCALAR_FUNCTIONS = {{"is_not_null", "isNotNull"}, {"is_null", "isNull"}, {"gte", "greaterOrEquals"}, {"gt", "greater"}, {"lte", "lessOrEquals"}, {"lt", "less"}, {"equal", "equals"}, {"and", "and"}, {"or", "or"}, {"not", "not"}, {"xor", "xor"}, {"extract", ""}, {"cast", "CAST"}, {"alias", "alias"}, /// datetime functions {"get_timestamp", "parseDateTimeInJodaSyntaxOrNull"}, // for spark function: to_date/to_timestamp {"quarter", "toQuarter"}, {"to_unix_timestamp", "parseDateTimeInJodaSyntaxOrNull"}, // {"unix_timestamp", "toUnixTimestamp"}, {"date_format", "formatDateTimeInJodaSyntax"}, /// arithmetic functions {"subtract", "minus"}, {"multiply", "multiply"}, {"add", "plus"}, {"divide", "divide"}, {"positive", "identity"}, {"negative", "negate"}, {"modulus", "modulo"}, {"pmod", "pmod"}, {"abs", "abs"}, {"ceil", "ceil"}, {"round", "roundHalfUp"}, {"bround", "roundBankers"}, {"exp", "exp"}, {"power", "power"}, {"cos", "cos"}, {"cosh", "cosh"}, {"sin", "sin"}, {"sinh", "sinh"}, {"tan", "tan"}, {"tanh", "tanh"}, {"acos", "acos"}, {"asin", "asin"}, {"atan", "atan"}, {"atan2", "atan2"}, {"asinh", "asinh"}, {"acosh", "acosh"}, {"atanh", "atanh"}, {"bitwise_not", "bitNot"}, {"bitwise_and", "bitAnd"}, {"bitwise_or", "bitOr"}, {"bitwise_xor", "bitXor"}, {"sqrt", "sqrt"}, {"cbrt", "cbrt"}, {"degrees", "degrees"}, {"e", "e"}, {"pi", "pi"}, {"hex", "hex"}, {"unhex", "unhex"}, {"hypot", "hypot"}, {"sign", "sign"}, {"radians", "radians"}, {"greatest", "greatest"}, {"least", "least"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"}, {"check_overflow", "checkDecimalOverflowSpark"}, {"factorial", "factorial"}, {"rand", "randCanonical"}, {"isnan", "isNaN"}, /// string functions {"like", "like"}, {"not_like", "notLike"}, {"starts_with", "startsWithUTF8"}, {"ends_with", "endsWithUTF8"}, {"contains", "countSubstrings"}, {"substring", "substringUTF8"}, {"substring_index", "substringIndexUTF8"}, {"lower", "lowerUTF8"}, {"upper", "upperUTF8"}, {"trim", ""}, // trimLeft or trimLeftSpark, depends on argument size {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size {"concat", ""}, /// dummy mapping {"strpos", "positionUTF8"}, {"char_length", "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length {"replace", "replaceAll"}, {"regexp_replace", "replaceRegexpAll"}, {"regexp_extract", "regexpExtract"}, {"regexp_extract_all", "regexpExtractAllSpark"}, {"chr", "char"}, {"rlike", "match"}, {"ascii", "ascii"}, {"split", "splitByRegexp"}, {"concat_ws", "concat_ws"}, {"base64", "base64Encode"}, {"unbase64", "base64Decode"}, {"lpad", "leftPadUTF8"}, {"rpad", "rightPadUTF8"}, {"reverse", ""}, /// dummy mapping {"md5", "MD5"}, {"translate", "translateUTF8"}, {"repeat", "repeat"}, {"position", "positionUTF8Spark"}, {"locate", "positionUTF8Spark"}, {"space", "space"}, {"initcap", "initcapUTF8"}, {"conv", "sparkConv"}, {"uuid", "generateUUIDv4"}, /// hash functions {"sha1", "SHA1"}, {"sha2", ""}, /// dummy mapping {"crc32", "CRC32"}, {"murmur3hash", "sparkMurmurHash3_32"}, {"xxhash64", "sparkXxHash64"}, // in functions {"in", "in"}, // null related functions {"coalesce", "coalesce"}, // date or datetime functions {"from_unixtime", "fromUnixTimestampInJodaSyntax"}, {"date_add", "addDays"}, {"date_sub", "subtractDays"}, {"datediff", "dateDiff"}, {"second", "toSecond"}, {"add_months", "addMonths"}, {"date_trunc", "dateTrunc"}, {"floor_datetime", "dateTrunc"}, {"floor", "sparkFloor"}, {"months_between", "sparkMonthsBetween"}, // array functions {"array", "array"}, {"size", "length"}, {"range", "range"}, /// dummy mapping // map functions {"map", "map"}, {"get_map_value", "arrayElement"}, {"map_keys", "mapKeys"}, {"map_values", "mapValues"}, {"map_from_arrays", "mapFromArrays"}, // tuple functions {"get_struct_field", "sparkTupleElement"}, {"get_array_struct_fields", "sparkTupleElement"}, {"named_struct", "tuple"}, // table-valued generator function {"explode", "arrayJoin"}, {"posexplode", "arrayJoin"}, // json functions {"flattenJSONStringOnRequired", "flattenJSONStringOnRequired"}, {"get_json_object", "get_json_object"}, {"to_json", "toJSONString"}, {"from_json", "JSONExtract"}, {"json_tuple", "json_tuple"}, {"json_array_length", "JSONArrayLength"}, {"make_decimal", "makeDecimalSpark"}, {"unscaled_value", "unscaleValueSpark"}, // runtime filter {"might_contain", "bloomFilterContains"}}; static const std::set FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"}; struct QueryContext { StorageSnapshotPtr storage_snapshot; std::shared_ptr metadata; std::shared_ptr custom_storage_merge_tree; }; DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type); DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c); class SerializedPlanParser; // Give a condition expression `cond_rel_`, found all columns with nullability that must not containt // null after this filter. // It's used to remove nullability of the columns for performance reason. class NonNullableColumnsResolver { public: explicit NonNullableColumnsResolver(const DB::Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_); ~NonNullableColumnsResolver() = default; // return column names std::set resolve(); private: DB::Block header; SerializedPlanParser & parser; const substrait::Expression & cond_rel; std::set collected_columns; void visit(const substrait::Expression & expr); void visitNonNullable(const substrait::Expression & expr); String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function); }; class SerializedPlanParser { private: friend class RelParser; friend class RelRewriter; friend class ASTParser; friend class FunctionParser; friend class AggregateFunctionParser; friend class FunctionExecutor; friend class NonNullableColumnsResolver; friend class JoinRelParser; friend class MergeTreeRelParser; public: explicit SerializedPlanParser(const ContextPtr & context); DB::QueryPlanPtr parse(const std::string & plan); DB::QueryPlanPtr parseJson(const std::string & json_plan); DB::QueryPlanPtr parse(std::unique_ptr plan); DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel); DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel); PrewhereInfoPtr parsePreWhereInfo(const substrait::Expression & rel, Block & input); static bool isReadRelFromJava(const substrait::ReadRel & rel); static bool isReadFromMergeTree(const substrait::ReadRel & rel); static substrait::ReadRel::LocalFiles parseLocalFiles(const std::string & split_info); static substrait::ReadRel::ExtensionTable parseExtensionTable(const std::string & split_info); void addInputIter(jobject iter, bool materialize_input) { input_iters.emplace_back(iter); materialize_inputs.emplace_back(materialize_input); } void addSplitInfo(std::string & split_info) { split_infos.emplace_back(std::move(split_info)); } int nextSplitInfoIndex() { if (split_info_index >= split_infos.size()) throw Exception(ErrorCodes::LOGICAL_ERROR, "split info index out of range, split_info_index: {}, split_infos.size(): {}", split_info_index, split_infos.size()); return split_info_index++; } void parseExtensions(const ::google::protobuf::RepeatedPtrField & extensions); std::shared_ptr expressionsToActionsDAG( const std::vector & expressions, const DB::Block & header, const DB::Block & read_schema); RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); bool convertBinaryArithmeticFunDecimalArgs( ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun); IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); static ContextMutablePtr global_context; static Context::ConfigurationPtr config; static SharedContextHolder shared_context; QueryContext query_context; std::vector extra_plan_holder; private: static DB::NamesAndTypesList blockToNameAndTypeList(const DB::Block & header); DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list & rel_stack); void collectJoinKeys(const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start); DB::ActionsDAGPtr parseFunction( const Block & header, const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false); DB::ActionsDAGPtr parseFunctionOrExpression( const Block & header, const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false); DB::ActionsDAGPtr parseArrayJoin( const Block & input, const substrait::Expression & rel, std::vector & result_names, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false, bool position = false); DB::ActionsDAGPtr parseJsonTuple( const Block & input, const substrait::Expression & rel, std::vector & result_names, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false, bool position = false); const ActionsDAG::Node * parseFunctionWithDAG( const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false); ActionsDAG::NodeRawConstPtrs parseArrayJoinWithDAG( const substrait::Expression & rel, std::vector & result_name, DB::ActionsDAGPtr actions_dag = nullptr, bool keep_result = false, bool position = false); void parseFunctionArguments( DB::ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, std::string & function_name, const substrait::Expression_ScalarFunction & scalar_function); void parseFunctionArgument( DB::ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, const std::string & function_name, const substrait::FunctionArgument & arg); const DB::ActionsDAG::Node * parseFunctionArgument(DB::ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg); const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAGPtr actions_dag, const substrait::Expression & rel); const ActionsDAG::Node * toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args); // remove nullable after isNotNull void removeNullable(const std::set & require_columns, ActionsDAGPtr actions_dag); std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names); static std::pair convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field); const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); int name_no = 0; std::unordered_map function_mapping; std::vector input_iters; std::vector split_infos; int split_info_index = 0; std::vector materialize_inputs; ContextPtr context; // for parse rel node, collect steps from a rel node std::vector temp_step_collection; std::vector metrics; }; struct SparkBuffer { char * address; size_t size; }; class LocalExecutor : public BlockIterator { public: LocalExecutor() = default; explicit LocalExecutor(QueryContext & _query_context, ContextPtr context); void execute(QueryPlanPtr query_plan); SparkRowInfoPtr next(); Block * nextColumnar(); bool hasNext(); ~LocalExecutor(); Block & getHeader(); RelMetricPtr getMetric() const { return metric; } void setMetric(RelMetricPtr metric_) { metric = metric_; } void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } private: QueryContext query_context; std::unique_ptr writeBlockToSparkRow(DB::Block & block); QueryPipeline query_pipeline; std::unique_ptr executor; Block header; ContextPtr context; std::unique_ptr ch_column_to_spark_row; std::unique_ptr spark_buffer; DB::QueryPlanPtr current_query_plan; RelMetricPtr metric; std::vector extra_plan_holder; /// Dump processor runtime information to log std::string dumpPipeline(); }; class ASTParser { public: explicit ASTParser(const ContextPtr & _context, std::unordered_map & _function_mapping) : context(_context), function_mapping(_function_mapping) { } ~ASTParser() = default; ASTPtr parseToAST(const Names & names, const substrait::Expression & rel); ActionsDAGPtr convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast); private: ContextPtr context; std::unordered_map function_mapping; void parseFunctionArgumentsToAST(const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args); ASTPtr parseArgumentToAST(const Names & names, const substrait::Expression & rel); }; }