cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc (1,193 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "SubstraitToVeloxPlanValidator.h" #include <google/protobuf/wrappers.pb.h> #include <re2/re2.h> #include <string> #include "TypeUtils.h" #include "udf/UdfLoader.h" #include "utils/Common.h" #include "velox/exec/Aggregate.h" #include "velox/expression/Expr.h" #include "velox/expression/SignatureBinder.h" namespace gluten { namespace { const char* extractFileName(const char* file) { return strrchr(file, '/') ? strrchr(file, '/') + 1 : file; } #define LOG_VALIDATION_MSG_FROM_EXCEPTION(err) \ logValidateMsg(fmt::format( \ "Validation failed due to exception caught at file:{} line:{} function:{}, thrown from file:{} line:{} function:{}, reason:{}", \ extractFileName(__FILE__), \ __LINE__, \ __FUNCTION__, \ extractFileName(err.file()), \ err.line(), \ err.function(), \ err.message())) #define LOG_VALIDATION_MSG(reason) \ logValidateMsg(fmt::format( \ "Validation failed at file:{}, line:{}, function:{}, reason:{}", \ extractFileName(__FILE__), \ __LINE__, \ __FUNCTION__, \ reason)) const std::unordered_set<std::string> kRegexFunctions = { "regexp_extract", "regexp_extract_all", "regexp_replace", "rlike"}; const std::unordered_set<std::string> kBlackList = {"split_part", "trunc", "sequence", "approx_percentile", "get_array_struct_fields", "map_from_arrays"}; } // namespace bool SubstraitToVeloxPlanValidator::parseVeloxType( const ::substrait::extensions::AdvancedExtension& extension, TypePtr& out) { ::substrait::Type substraitType; // The input type is wrapped in enhancement. if (!extension.has_enhancement()) { LOG_VALIDATION_MSG("Input type is not wrapped in enhancement."); return false; } const auto& enhancement = extension.enhancement(); if (!enhancement.UnpackTo(&substraitType)) { LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType."); return false; } out = SubstraitParser::parseType(substraitType); return true; } bool SubstraitToVeloxPlanValidator::flattenSingleLevel(const TypePtr& type, std::vector<TypePtr>& out) { if (!type->isRow()) { LOG_VALIDATION_MSG("Type is not a RowType."); return false; } auto rowType = std::dynamic_pointer_cast<const RowType>(type); if (!rowType) { LOG_VALIDATION_MSG("Failed to cast to RowType."); return false; } for (const auto& field : rowType->children()) { out.emplace_back(field); } return true; } bool SubstraitToVeloxPlanValidator::flattenDualLevel(const TypePtr& type, std::vector<std::vector<TypePtr>>& out) { if (!type->isRow()) { LOG_VALIDATION_MSG("Type is not a RowType."); return false; } auto rowType = std::dynamic_pointer_cast<const RowType>(type); if (!rowType) { LOG_VALIDATION_MSG("Failed to cast to RowType."); return false; } for (const auto& field : rowType->children()) { std::vector<TypePtr> inner; if (!flattenSingleLevel(field, inner)) { return false; } out.emplace_back(inner); } return true; } bool SubstraitToVeloxPlanValidator::validateRound( const ::substrait::Expression::ScalarFunction& scalarFunction, const RowTypePtr& inputType) { const auto& arguments = scalarFunction.arguments(); if (arguments.size() < 2) { return false; } if (!arguments[1].value().has_literal()) { LOG_VALIDATION_MSG("Round scale is expected."); return false; } // Velox has different result with Spark on negative scale. auto typeCase = arguments[1].value().literal().literal_type_case(); switch (typeCase) { case ::substrait::Expression_Literal::LiteralTypeCase::kI32: return (arguments[1].value().literal().i32() >= 0); case ::substrait::Expression_Literal::LiteralTypeCase::kI64: return (arguments[1].value().literal().i64() >= 0); default: LOG_VALIDATION_MSG("Round scale validation is not supported for type case " + std::to_string(typeCase)); return false; } } bool SubstraitToVeloxPlanValidator::validateExtractExpr(const std::vector<core::TypedExprPtr>& params) { if (params.size() != 2) { LOG_VALIDATION_MSG("Value expected in variant in ExtractExpr."); return false; } auto functionArg = std::dynamic_pointer_cast<const core::ConstantTypedExpr>(params[0]); if (functionArg) { // Get the function argument. const auto& variant = functionArg->value(); if (!variant.hasValue()) { LOG_VALIDATION_MSG("Value expected in variant in ExtractExpr."); return false; } return true; } LOG_VALIDATION_MSG("Constant is expected to be the first parameter in extract."); return false; } bool SubstraitToVeloxPlanValidator::validateRegexExpr( const std::string& name, const ::substrait::Expression::ScalarFunction& scalarFunction) { if (scalarFunction.arguments().size() < 2) { LOG_VALIDATION_MSG("Wrong number of arguments for " + name); } const auto& patternArg = scalarFunction.arguments()[1].value(); if (!patternArg.has_literal() || !patternArg.literal().has_string()) { LOG_VALIDATION_MSG("Pattern is not string literal for " + name); return false; } const auto& pattern = patternArg.literal().string(); std::string error; if (!validatePattern(pattern, error)) { LOG_VALIDATION_MSG(name + " due to " + error); return false; } return true; } bool SubstraitToVeloxPlanValidator::validateScalarFunction( const ::substrait::Expression::ScalarFunction& scalarFunction, const RowTypePtr& inputType) { std::vector<core::TypedExprPtr> params; params.reserve(scalarFunction.arguments().size()); for (const auto& argument : scalarFunction.arguments()) { if (argument.has_value() && !validateExpression(argument.value(), inputType)) { return false; } params.emplace_back(exprConverter_->toVeloxExpr(argument.value(), inputType)); } const auto& function = SubstraitParser::findFunctionSpec(planConverter_.getFunctionMap(), scalarFunction.function_reference()); const auto& name = SubstraitParser::getNameBeforeDelimiter(function); std::vector<std::string> types = SubstraitParser::getSubFunctionTypes(function); if (name == "round") { return validateRound(scalarFunction, inputType); } if (name == "extract") { return validateExtractExpr(params); } // Validate regex functions. if (kRegexFunctions.find(name) != kRegexFunctions.end()) { return validateRegexExpr(name, scalarFunction); } if (kBlackList.find(name) != kBlackList.end()) { LOG_VALIDATION_MSG("Function is not supported: " + name); return false; } return true; } bool isSupportedArrayCast(const TypePtr& fromType, const TypePtr& toType) { static const std::unordered_set<TypeKind> kAllowedArrayElementKinds = { TypeKind::DOUBLE, TypeKind::BOOLEAN, TypeKind::TIMESTAMP, }; // https://github.com/apache/incubator-gluten/issues/9392 // is currently WIP to add support for other types. if (toType->isVarchar()) { return kAllowedArrayElementKinds.count(fromType->kind()) > 0; } if (toType->isDouble()) { if (fromType->isInteger() || fromType->isBigint() || fromType->isSmallint() || fromType->isTinyint()) { return true; } } return false; } bool SubstraitToVeloxPlanValidator::isAllowedCast(const TypePtr& fromType, const TypePtr& toType) { // Currently cast is not allowed for various categories, code has a bunch of rules // which define the cast categories and if we should offload to velox. Currently, // the following categories are denied. // // 1. from/to isIntervalYearMonth is not allowed. // 2. Date to most categories except few supported types is not allowed. // 3. Timestamp to most categories except few supported types is not allowed. // 4. Certain complex types are not allowed. // Don't support isIntervalYearMonth. if (fromType->isIntervalYearMonth() || toType->isIntervalYearMonth()) { return false; } // Limited support for DATE to X. if (fromType->isDate() && !toType->isTimestamp() && !toType->isVarchar()) { return false; } // Limited support for Timestamp to X. if (fromType->isTimestamp()) { if (toType->isDecimal()) { return false; } if (toType->isBigint()) { return true; } if (toType->isDate() || toType->isVarchar()) { return true; } return false; } // Limited support for X to Timestamp. if (toType->isTimestamp()) { if (fromType->isDecimal()) { return false; } if (fromType->isDate()) { return true; } if (fromType->isVarchar()) { return true; } if (fromType->isTinyint() || fromType->isSmallint() || fromType->isInteger() || fromType->isBigint() || fromType->isDouble() || fromType->isReal()) { return true; } return false; } if (fromType->isArray() && toType->isArray()) { const auto& toElem = toType->asArray().elementType(); const auto& fromElem = fromType->asArray().elementType(); if (!isAllowedCast(fromElem, toElem)) { return false; } return isSupportedArrayCast(fromElem, toElem); } // Limited support for Complex types. if (fromType->isArray() || fromType->isMap() || fromType->isRow()) { return false; } if (fromType->isVarbinary() && !toType->isVarchar()) { return false; } return true; } bool SubstraitToVeloxPlanValidator::validateCast( const ::substrait::Expression::Cast& castExpr, const RowTypePtr& inputType) { if (!validateExpression(castExpr.input(), inputType)) { return false; } const auto& toType = SubstraitParser::parseType(castExpr.type()); core::TypedExprPtr input = exprConverter_->toVeloxExpr(castExpr.input(), inputType); if (SubstraitToVeloxPlanValidator::isAllowedCast(input->type(), toType)) { return true; } LOG_VALIDATION_MSG("Casting from " + input->type()->toString() + " to " + toType->toString() + " is not supported."); return false; } bool SubstraitToVeloxPlanValidator::validateIfThen( const ::substrait::Expression_IfThen& ifThen, const RowTypePtr& inputType) { for (const auto& subIfThen : ifThen.ifs()) { if (!validateExpression(subIfThen.if_(), inputType) || !validateExpression(subIfThen.then(), inputType)) { return false; } } if (ifThen.has_else_() && !validateExpression(ifThen.else_(), inputType)) { return false; } return true; } bool SubstraitToVeloxPlanValidator::validateSingularOrList( const ::substrait::Expression::SingularOrList& singularOrList, const RowTypePtr& inputType) { for (const auto& option : singularOrList.options()) { if (!option.has_literal()) { LOG_VALIDATION_MSG("Option is expected as Literal."); return false; } } return validateExpression(singularOrList.value(), inputType); } bool SubstraitToVeloxPlanValidator::validateExpression( const ::substrait::Expression& expression, const RowTypePtr& inputType) { auto typeCase = expression.rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kScalarFunction: return validateScalarFunction(expression.scalar_function(), inputType); case ::substrait::Expression::RexTypeCase::kCast: return validateCast(expression.cast(), inputType); case ::substrait::Expression::RexTypeCase::kIfThen: return validateIfThen(expression.if_then(), inputType); case ::substrait::Expression::RexTypeCase::kSingularOrList: return validateSingularOrList(expression.singular_or_list(), inputType); default: return true; } } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeRel) { if (writeRel.has_input() && !validate(writeRel.input())) { LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel."); return false; } // Validate input data type. TypePtr inputRowType; std::vector<TypePtr> types; if (writeRel.has_named_table()) { const auto& extension = writeRel.named_table().advanced_extension(); if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel."); return false; } } // Validate partition key type. if (writeRel.has_table_schema()) { const auto& tableSchema = writeRel.table_schema(); std::vector<ColumnType> columnTypes; SubstraitParser::parseColumnTypes(tableSchema, columnTypes); for (auto i = 0; i < types.size(); i++) { if (columnTypes[i] == ColumnType::kPartitionKey) { switch (types[i]->kind()) { case TypeKind::BOOLEAN: case TypeKind::TINYINT: case TypeKind::SMALLINT: case TypeKind::INTEGER: case TypeKind::BIGINT: case TypeKind::VARCHAR: case TypeKind::VARBINARY: break; default: LOG_VALIDATION_MSG( "Validation failed for input type validation in WriteRel, not support partition column type: " + mapTypeKindToName(types[i]->kind())); return false; } } } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchRel) { // Get and validate the input types from extension. if (fetchRel.has_advanced_extension()) { const auto& extension = fetchRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in FetchRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } } if (fetchRel.offset() < 0 || fetchRel.count() < 0) { LOG_VALIDATION_MSG("Offset and count should be valid in FetchRel."); return false; } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::TopNRel& topNRel) { RowTypePtr rowType = nullptr; // Get and validate the input types from extension. if (topNRel.has_advanced_extension()) { const auto& extension = topNRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in TopNRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } rowType = std::make_shared<RowType>(std::move(names), std::move(types)); } if (topNRel.n() < 0) { LOG_VALIDATION_MSG("N should be valid in TopNRel."); return false; } auto [sortingKeys, sortingOrders] = planConverter_.processSortField(topNRel.sorts(), rowType); folly::F14FastSet<std::string> sortingKeyNames; for (const auto& sortingKey : sortingKeys) { auto result = sortingKeyNames.insert(sortingKey->name()); if (!result.second) { LOG_VALIDATION_MSG("Duplicate sort keys were found in TopNRel."); return false; } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::GenerateRel& generateRel) { if (generateRel.has_input() && !validate(generateRel.input())) { LOG_VALIDATION_MSG("Input validation fails in GenerateRel."); return false; } // Get and validate the input types from extension. if (!generateRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in GenerateRel."); return false; } const auto& extension = generateRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel."); return false; } int32_t inputPlanNodeId = 0; // Create the fake input names to be used in row type. std::vector<std::string> names; names.reserve(types.size()); for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); if (generateRel.has_generator() && !validateExpression(generateRel.generator(), rowType)) { LOG_VALIDATION_MSG("Input validation fails in GenerateRel."); return false; } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expandRel) { if (expandRel.has_input() && !validate(expandRel.input())) { LOG_VALIDATION_MSG("Input validation fails in ExpandRel."); return false; } RowTypePtr rowType = nullptr; // Get and validate the input types from extension. if (expandRel.has_advanced_extension()) { const auto& extension = expandRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in ExpandRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } rowType = std::make_shared<RowType>(std::move(names), std::move(types)); } int32_t projectSize = 0; // Validate fields. for (const auto& fields : expandRel.fields()) { std::vector<core::TypedExprPtr> expressions; if (fields.has_switching_field()) { auto projectExprs = fields.switching_field().duplicates(); expressions.reserve(projectExprs.size()); if (projectSize == 0) { projectSize = projectExprs.size(); } else if (projectSize != projectExprs.size()) { LOG_VALIDATION_MSG("SwitchingField expressions size should be constant in ExpandRel."); return false; } for (const auto& projectExpr : projectExprs) { const auto& typeCase = projectExpr.rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: case ::substrait::Expression::RexTypeCase::kLiteral: break; default: LOG_VALIDATION_MSG("Only field or literal is supported in project of ExpandRel."); return false; } if (rowType) { expressions.emplace_back(exprConverter_->toVeloxExpr(projectExpr, rowType)); } } if (rowType) { // Try to compile the expressions. If there is any unregistered // function or mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); } } else { LOG_VALIDATION_MSG("Only SwitchingField is supported in ExpandRel."); return false; } } return true; } bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) { switch (boundType.kind_case()) { case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing: case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding: case ::substrait::Expression_WindowFunction_Bound::kCurrentRow: case ::substrait::Expression_WindowFunction_Bound::kFollowing: case ::substrait::Expression_WindowFunction_Bound::kPreceding: break; default: return false; } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windowRel) { if (windowRel.has_input() && !validate(windowRel.input())) { LOG_VALIDATION_MSG("WindowRel input fails to validate."); return false; } // Get and validate the input types from extension. if (!windowRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in WindowRel."); return false; } const auto& extension = windowRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowRel."); return false; } if (types.empty()) { // See: https://github.com/apache/incubator-gluten/issues/7600. LOG_VALIDATION_MSG("Validation failed for empty input schema in WindowRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); // Validate WindowFunction std::vector<std::string> funcSpecs; funcSpecs.reserve(windowRel.measures().size()); for (const auto& smea : windowRel.measures()) { const auto& windowFunction = smea.measure(); funcSpecs.emplace_back(planConverter_.findFuncSpec(windowFunction.function_reference())); SubstraitParser::parseType(windowFunction.output_type()); for (const auto& arg : windowFunction.arguments()) { auto typeCase = arg.value().rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: case ::substrait::Expression::RexTypeCase::kLiteral: break; default: LOG_VALIDATION_MSG("Only field or constant is supported in window functions."); return false; } } // Validate BoundType and Frame Type switch (windowFunction.window_type()) { case ::substrait::WindowType::ROWS: case ::substrait::WindowType::RANGE: break; default: LOG_VALIDATION_MSG( "the window type only support ROWS and RANGE, and the input type is " + std::to_string(windowFunction.window_type())); return false; } bool boundTypeSupported = validateBoundType(windowFunction.upper_bound()) && validateBoundType(windowFunction.lower_bound()); if (!boundTypeSupported) { LOG_VALIDATION_MSG( "Found unsupported Bound Type: upper " + std::to_string(windowFunction.upper_bound().kind_case()) + ", lower " + std::to_string(windowFunction.lower_bound().kind_case())); return false; } } // Validate groupby expression const auto& groupByExprs = windowRel.partition_expressions(); std::vector<core::TypedExprPtr> expressions; expressions.reserve(groupByExprs.size()); for (const auto& expr : groupByExprs) { auto expression = exprConverter_->toVeloxExpr(expr, rowType); auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get()); if (exprField == nullptr) { LOG_VALIDATION_MSG("Only field is supported for partition key in Window Operator!"); return false; } else { expressions.emplace_back(expression); } } // Try to compile the expressions. If there is any unregistred funciton or // mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); // Validate Sort expression const auto& sorts = windowRel.sorts(); for (const auto& sort : sorts) { switch (sort.direction()) { case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: break; default: LOG_VALIDATION_MSG("in windowRel, unsupported Sort direction " + std::to_string(sort.direction())); return false; } if (sort.has_expr()) { auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get()); if (!exprField) { LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field."); return false; } exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get()); } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel) { if (windowGroupLimitRel.has_input() && !validate(windowGroupLimitRel.input())) { LOG_VALIDATION_MSG("WindowGroupLimitRel input fails to validate."); return false; } // Get and validate the input types from extension. if (!windowGroupLimitRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in WindowGroupLimitRel."); return false; } const auto& extension = windowGroupLimitRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowGroupLimitRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); // Validate groupby expression const auto& groupByExprs = windowGroupLimitRel.partition_expressions(); std::vector<core::TypedExprPtr> expressions; expressions.reserve(groupByExprs.size()); for (const auto& expr : groupByExprs) { auto expression = exprConverter_->toVeloxExpr(expr, rowType); auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get()); if (exprField == nullptr) { LOG_VALIDATION_MSG("Only field is supported for partition key in Window Group Limit Operator!"); return false; } expressions.emplace_back(expression); } // Try to compile the expressions. If there is any unregistered function or // mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); // Validate Sort expression const auto& sorts = windowGroupLimitRel.sorts(); for (const auto& sort : sorts) { switch (sort.direction()) { case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: break; default: LOG_VALIDATION_MSG("in windowGroupLimitRel, unsupported Sort direction " + std::to_string(sort.direction())); return false; } if (sort.has_expr()) { auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get()); if (!exprField) { LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field."); return false; } exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get()); } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel& setRel) { switch (setRel.op()) { case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: { for (int32_t i = 0; i < setRel.inputs_size(); ++i) { const auto& input = setRel.inputs(i); if (!validate(input)) { LOG_VALIDATION_MSG("ProjectRel input"); return false; } } if (!setRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in SetRel."); return false; } const auto& extension = setRel.advanced_extension(); TypePtr inputRowType; std::vector<std::vector<TypePtr>> childrenTypes; if (!parseVeloxType(extension, inputRowType) || !flattenDualLevel(inputRowType, childrenTypes)) { LOG_VALIDATION_MSG("Validation failed for input types in SetRel."); return false; } std::vector<RowTypePtr> childrenRowTypes; for (auto i = 0; i < childrenTypes.size(); ++i) { auto& types = childrenTypes.at(i); std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(i, colIdx)); } childrenRowTypes.push_back(std::make_shared<RowType>(std::move(names), std::move(types))); } for (auto i = 1; i < childrenRowTypes.size(); ++i) { if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) { LOG_VALIDATION_MSG( "All sources of the Set operation must have the same output type: " + childrenRowTypes[i]->toString() + " vs. " + childrenRowTypes[0]->toString()); return false; } } return true; } default: LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op())); return false; } } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel) { if (sortRel.has_input() && !validate(sortRel.input())) { return false; } // Get and validate the input types from extension. if (!sortRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in SortRel."); return false; } const auto& extension = sortRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in SortRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); const auto& sorts = sortRel.sorts(); for (const auto& sort : sorts) { switch (sort.direction()) { case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: break; default: LOG_VALIDATION_MSG("unsupported Sort direction " + std::to_string(sort.direction())); return false; } if (sort.has_expr()) { auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); auto exprField = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get()); if (!exprField) { LOG_VALIDATION_MSG("in SortRel, the sorting key in Sort Operator only support field."); return false; } exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& projectRel) { if (projectRel.has_input() && !validate(projectRel.input())) { LOG_VALIDATION_MSG("ProjectRel input"); return false; } // Get and validate the input types from extension. if (!projectRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in ProjectRel."); return false; } const auto& extension = projectRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel."); return false; } int32_t inputPlanNodeId = 0; // Create the fake input names to be used in row type. std::vector<std::string> names; names.reserve(types.size()); for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); // Validate the project expressions. const auto& projectExprs = projectRel.expressions(); std::vector<core::TypedExprPtr> expressions; expressions.reserve(projectExprs.size()); for (const auto& expr : projectExprs) { if (!validateExpression(expr, rowType)) { return false; } expressions.emplace_back(exprConverter_->toVeloxExpr(expr, rowType)); } // Try to compile the expressions. If there is any unregistered function or // mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filterRel) { if (filterRel.has_input() && !validate(filterRel.input())) { LOG_VALIDATION_MSG("input of FilterRel validation fails"); return false; } // Get and validate the input types from extension. if (!filterRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in FilterRel."); return false; } const auto& extension = filterRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in FilterRel."); return false; } int32_t inputPlanNodeId = 0; // Create the fake input names to be used in row type. std::vector<std::string> names; names.reserve(types.size()); for (uint32_t colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); std::vector<core::TypedExprPtr> expressions; if (!validateExpression(filterRel.condition(), rowType)) { return false; } expressions.emplace_back(exprConverter_->toVeloxExpr(filterRel.condition(), rowType)); // Try to compile the expressions. If there is any unregistered function // or mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel) { if (joinRel.has_left() && !validate(joinRel.left())) { LOG_VALIDATION_MSG("Validation fails for join left input."); return false; } if (joinRel.has_right() && !validate(joinRel.right())) { LOG_VALIDATION_MSG("Validation fails for join right input."); return false; } if (joinRel.has_advanced_extension() && SubstraitParser::configSetInOptimization(joinRel.advanced_extension(), "isSMJ=")) { switch (joinRel.type()) { case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI: break; default: LOG_VALIDATION_MSG("Sort merge join type is not supported: " + std::to_string(joinRel.type())); return false; } } switch (joinRel.type()) { case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI: case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI: break; default: LOG_VALIDATION_MSG("Join type is not supported: " + std::to_string(joinRel.type())); return false; } // Validate input types. if (!joinRel.has_advanced_extension()) { LOG_VALIDATION_MSG("Input types are expected in JoinRel."); return false; } const auto& extension = joinRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in JoinRel."); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); if (joinRel.has_expression()) { std::vector<const ::substrait::Expression::FieldReference*> leftExprs, rightExprs; planConverter_.extractJoinKeys(joinRel.expression(), leftExprs, rightExprs); } if (joinRel.has_post_join_filter()) { auto expression = exprConverter_->toVeloxExpr(joinRel.post_join_filter(), rowType); exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossRel) { if (crossRel.has_left() && !validate(crossRel.left())) { logValidateMsg("Native validation failed due to: validation fails for cross join left input. "); return false; } if (crossRel.has_right() && !validate(crossRel.right())) { logValidateMsg("Native validation failed due to: validation fails for cross join right input. "); return false; } // Validate input types. if (!crossRel.has_advanced_extension()) { logValidateMsg("Native validation failed due to: Input types are expected in CrossRel."); return false; } switch (crossRel.type()) { case ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER: case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: break; case ::substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: if (crossRel.has_expression()) { LOG_VALIDATION_MSG("Full outer join type with condition is not supported in CrossRel"); return false; } else { break; } default: LOG_VALIDATION_MSG("Unsupported Join type in CrossRel"); return false; } const auto& extension = crossRel.advanced_extension(); TypePtr inputRowType; std::vector<TypePtr> types; if (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types)) { logValidateMsg("Native validation failed due to: Validation failed for input types in CrossRel"); return false; } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(types.size()); for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(types)); if (crossRel.has_expression()) { auto expression = exprConverter_->toVeloxExpr(crossRel.expression(), rowType); exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } return true; } bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait::AggregateRel& aggRel) { if (aggRel.measures_size() == 0) { return true; } for (const auto& smea : aggRel.measures()) { const auto& aggFunction = smea.measure(); const auto& funcStep = planConverter_.toAggregationFunctionStep(aggFunction); auto funcSpec = planConverter_.findFuncSpec(aggFunction.function_reference()); std::vector<TypePtr> types; bool isDecimal = false; types = SubstraitParser::sigToTypes(funcSpec); for (const auto& type : types) { if (!isDecimal && type->isDecimal()) { isDecimal = true; } } auto baseFuncName = SubstraitParser::mapToVeloxFunction(SubstraitParser::getNameBeforeDelimiter(funcSpec), isDecimal); auto funcName = planConverter_.toAggregationFunctionName(baseFuncName, funcStep); auto signaturesOpt = exec::getAggregateFunctionSignatures(funcName); if (!signaturesOpt) { LOG_VALIDATION_MSG("can not find function signature for " + funcName + " in AggregateRel."); return false; } bool resolved = false; for (const auto& signature : signaturesOpt.value()) { exec::SignatureBinder binder(*signature, types); if (binder.tryBind()) { auto resolveType = binder.tryResolveType( exec::isPartialOutput(funcStep) ? signature->intermediateType() : signature->returnType()); if (resolveType == nullptr) { LOG_VALIDATION_MSG("Validation failed for function " + funcName + " resolve type in AggregateRel."); return false; } resolved = true; break; } } if (!resolved) { LOG_VALIDATION_MSG("Validation failed for function " + funcName + " bind signatures in AggregateRel."); return false; } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel) { if (aggRel.has_input() && !validate(aggRel.input())) { LOG_VALIDATION_MSG("Input validation fails in AggregateRel."); return false; } // Validate input types. if (aggRel.has_advanced_extension()) { TypePtr inputRowType; std::vector<TypePtr> types; const auto& extension = aggRel.advanced_extension(); // Aggregate always has advanced extension for streaming aggregate optimization, // but only some of them have enhancement for validation. if (extension.has_enhancement() && (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types))) { LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel."); return false; } } // Validate groupings. for (const auto& grouping : aggRel.groupings()) { for (const auto& groupingExpr : grouping.grouping_expressions()) { const auto& typeCase = groupingExpr.rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: break; default: LOG_VALIDATION_MSG("Only field is supported in groupings."); return false; } } } // Validate aggregate functions. std::vector<std::string> funcSpecs; funcSpecs.reserve(aggRel.measures().size()); for (const auto& smea : aggRel.measures()) { // Validate the filter expression if (smea.has_filter()) { ::substrait::Expression aggRelMask = smea.filter(); if (aggRelMask.ByteSizeLong() > 0) { auto typeCase = aggRelMask.rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: break; default: LOG_VALIDATION_MSG("Only field is supported in aggregate filter expression."); return false; } } } const auto& aggFunction = smea.measure(); const auto& functionSpec = planConverter_.findFuncSpec(aggFunction.function_reference()); funcSpecs.emplace_back(functionSpec); SubstraitParser::parseType(aggFunction.output_type()); // Validate the size of arguments. if (SubstraitParser::getNameBeforeDelimiter(functionSpec) == "count" && aggFunction.arguments().size() > 1) { LOG_VALIDATION_MSG("Count should have only one argument."); // Count accepts only one argument. return false; } for (const auto& arg : aggFunction.arguments()) { auto typeCase = arg.value().rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: case ::substrait::Expression::RexTypeCase::kLiteral: break; default: LOG_VALIDATION_MSG("Only field is supported in aggregate functions."); return false; } } } // The supported aggregation functions. TODO: Remove this set when Presto aggregate functions in Velox are not // needed to be registered. static const std::unordered_set<std::string> supportedAggFuncs = { "sum", "collect_set", "collect_list", "count", "avg", "min", "max", "min_by", "max_by", "stddev_samp", "stddev_pop", "bloom_filter_agg", "var_samp", "var_pop", "bit_and", "bit_or", "bit_xor", "first", "first_ignore_null", "last", "last_ignore_null", "corr", "regr_r2", "covar_pop", "covar_samp", "approx_distinct", "skewness", "kurtosis", "regr_slope", "regr_intercept", "regr_sxy", "regr_replacement"}; auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames(); for (const auto& funcSpec : funcSpecs) { auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec); if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) { LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel."); return false; } } if (!validateAggRelFunctionType(aggRel)) { return false; } // Validate both groupby and aggregates input are empty, which is corner case. if (aggRel.measures_size() == 0) { bool hasExpr = false; for (const auto& grouping : aggRel.groupings()) { if (grouping.grouping_expressions().size() > 0) { hasExpr = true; break; } } if (!hasExpr) { LOG_VALIDATION_MSG("Aggregation must specify either grouping keys or aggregates."); return false; } } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ReadRel& readRel) { planConverter_.toVeloxPlan(readRel); // Validate filter in ReadRel. if (readRel.has_filter()) { std::vector<TypePtr> veloxTypeList; if (readRel.has_base_schema()) { const auto& baseSchema = readRel.base_schema(); veloxTypeList = SubstraitParser::parseNamedStruct(baseSchema); } int32_t inputPlanNodeId = 0; std::vector<std::string> names; names.reserve(veloxTypeList.size()); for (auto colIdx = 0; colIdx < veloxTypeList.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } auto rowType = std::make_shared<RowType>(std::move(names), std::move(veloxTypeList)); std::vector<core::TypedExprPtr> expressions; if (!validateExpression(readRel.filter(), rowType)) { return false; } expressions.emplace_back(exprConverter_->toVeloxExpr(readRel.filter(), rowType)); // Try to compile the expressions. If there is any unregistered function // or mismatched type, exception will be thrown. exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); } return true; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& rel) { if (rel.has_aggregate()) { return validate(rel.aggregate()); } if (rel.has_project()) { return validate(rel.project()); } if (rel.has_filter()) { return validate(rel.filter()); } if (rel.has_join()) { return validate(rel.join()); } if (rel.has_cross()) { return validate(rel.cross()); } if (rel.has_read()) { return validate(rel.read()); } if (rel.has_sort()) { return validate(rel.sort()); } if (rel.has_expand()) { return validate(rel.expand()); } if (rel.has_generate()) { return validate(rel.generate()); } if (rel.has_fetch()) { return validate(rel.fetch()); } if (rel.has_top_n()) { return validate(rel.top_n()); } if (rel.has_window()) { return validate(rel.window()); } if (rel.has_write()) { return validate(rel.write()); } if (rel.has_windowgrouplimit()) { return validate(rel.windowgrouplimit()); } if (rel.has_set()) { return validate(rel.set()); } LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName()); return false; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::RelRoot& relRoot) { if (relRoot.has_input()) { return validate(relRoot.input()); } return false; } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Plan& plan) { try { // Create plan converter and expression converter to help the validation. planConverter_.constructFunctionMap(plan); exprConverter_ = planConverter_.getExprConverter(); for (const auto& rel : plan.relations()) { if (rel.has_root()) { return validate(rel.root()); } if (rel.has_rel()) { return validate(rel.rel()); } } return false; } catch (const VeloxException& err) { LOG_VALIDATION_MSG_FROM_EXCEPTION(err); return false; } } } // namespace gluten