in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [580:715]
// Validates that a Substrait WindowRel can be handled by the Velox backend.
//
// The checks run in this order, and the first failing one logs a reason via
// LOG_VALIDATION_MSG / LOG_VALIDATION_MSG_FROM_EXCEPTION and returns false:
//   1. The input relation, when present, must itself validate.
//   2. An advanced extension must be present and validateInputTypes() must
//      accept it; the resulting types form the input row type.
//   3. Every window measure must resolve to a function spec, have a parsable
//      output type, take only field-reference or literal arguments, use a
//      ROWS or RANGE window type, and have upper/lower bounds accepted by
//      validateBoundType().
//   4. No measure may use collect_list or collect_set.
//   5. Every partition expression must convert to a FieldAccessTypedExpr and
//      the whole set must compile into an exec::ExprSet.
//   6. Every sort field must use one of the four ASC/DESC x NULLS FIRST/LAST
//      directions and, when it carries an expression, that expression must
//      be a field access that compiles into an exec::ExprSet.
//
// Only VeloxException is caught and turned into a `false` result; any other
// exception type propagates to the caller.
//
// @param windowRel The Substrait window relation to validate.
// @return true if every check above passes, false otherwise.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windowRel) {
  if (windowRel.has_input() && !validate(windowRel.input())) {
    LOG_VALIDATION_MSG("WindowRel input fails to validate.");
    return false;
  }

  // Get and validate the input types from extension.
  if (!windowRel.has_advanced_extension()) {
    LOG_VALIDATION_MSG("Input types are expected in WindowRel.");
    return false;
  }
  const auto& extension = windowRel.advanced_extension();
  std::vector<TypePtr> types;
  if (!validateInputTypes(extension, types)) {
    LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
    return false;
  }

  // Build the input row type with one generated column name per input type.
  const int32_t inputPlanNodeId = 0;
  // Cast once so the loop index and its bound share a signed type.
  const auto numColumns = static_cast<int32_t>(types.size());
  std::vector<std::string> names;
  names.reserve(types.size());
  for (int32_t colIdx = 0; colIdx < numColumns; ++colIdx) {
    names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
  }
  auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

  // Validate WindowFunction
  std::vector<std::string> funcSpecs;
  funcSpecs.reserve(windowRel.measures().size());
  for (const auto& smea : windowRel.measures()) {
    try {
      const auto& windowFunction = smea.measure();
      funcSpecs.emplace_back(planConverter_.findFuncSpec(windowFunction.function_reference()));
      // The parsed type is discarded; the call is made for the exception it
      // raises on failure, which the catch below turns into a validation
      // failure.
      SubstraitParser::parseType(windowFunction.output_type());

      // Arguments may only be field references (selection) or literals.
      for (const auto& arg : windowFunction.arguments()) {
        auto typeCase = arg.value().rex_type_case();
        switch (typeCase) {
          case ::substrait::Expression::RexTypeCase::kSelection:
          case ::substrait::Expression::RexTypeCase::kLiteral:
            break;
          default:
            LOG_VALIDATION_MSG("Only field or constant is supported in window functions.");
            return false;
        }
      }

      // Validate BoundType and Frame Type
      switch (windowFunction.window_type()) {
        case ::substrait::WindowType::ROWS:
        case ::substrait::WindowType::RANGE:
          break;
        default:
          LOG_VALIDATION_MSG(
              "the window type only support ROWS and RANGE, and the input type is " +
              std::to_string(windowFunction.window_type()));
          return false;
      }

      bool boundTypeSupported =
          validateBoundType(windowFunction.upper_bound()) && validateBoundType(windowFunction.lower_bound());
      if (!boundTypeSupported) {
        LOG_VALIDATION_MSG(
            "Found unsupported Bound Type: upper " + std::to_string(windowFunction.upper_bound().kind_case()) +
            ", lower " + std::to_string(windowFunction.lower_bound().kind_case()));
        return false;
      }
    } catch (const VeloxException& err) {
      LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
      return false;
    }
  }

  // Validate supported aggregate functions. This runs after all measures have
  // been checked above, so measure-level failures are reported first.
  static const std::unordered_set<std::string> unsupportedFuncs = {"collect_list", "collect_set"};
  for (const auto& funcSpec : funcSpecs) {
    auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
    if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) {
      LOG_VALIDATION_MSG(funcName + " was not supported in WindowRel.");
      return false;
    }
  }

  // Validate groupby expression
  const auto& groupByExprs = windowRel.partition_expressions();
  std::vector<core::TypedExprPtr> expressions;
  expressions.reserve(groupByExprs.size());
  try {
    for (const auto& expr : groupByExprs) {
      auto expression = exprConverter_->toVeloxExpr(expr, rowType);
      auto expr_field = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
      if (expr_field == nullptr) {
        LOG_VALIDATION_MSG("Only field is supported for partition key in Window Operator!");
        return false;
      }
      // The raw pointer above was only needed for the type check, so the
      // owning pointer can be moved rather than copied.
      expressions.emplace_back(std::move(expression));
    }
    // Try to compile the expressions. If there is any unregistered function or
    // mismatched type, exception will be thrown.
    exec::ExprSet exprSet(std::move(expressions), execCtx_);
  } catch (const VeloxException& err) {
    LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
    return false;
  }

  // Validate Sort expression
  const auto& sorts = windowRel.sorts();
  for (const auto& sort : sorts) {
    switch (sort.direction()) {
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
        break;
      default:
        LOG_VALIDATION_MSG("in windowRel, unsupported Sort direction " + std::to_string(sort.direction()));
        return false;
    }

    // NOTE(review): a sort field without an expression passes validation
    // here; confirm that the plan converter tolerates that case.
    if (sort.has_expr()) {
      try {
        auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType);
        auto expr_field = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
        if (!expr_field) {
          LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field.");
          return false;
        }
        // Compile the single sort key; an exception here means the
        // expression cannot be evaluated.
        exec::ExprSet exprSet({std::move(expression)}, execCtx_);
      } catch (const VeloxException& err) {
        LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
        return false;
      }
    }
  }

  return true;
}