in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [1172:1315]
// Validates an AggregateRel. The checks below run in order and the function returns false, after logging the
// reason through LOG_VALIDATION_MSG, at the first one that fails:
//   1. the input relation (when present) validates;
//   2. the input row type carried by the advanced-extension enhancement (when present) parses and flattens;
//   3. every grouping expression is a field selection;
//   4. for every measure: a non-empty filter is a field selection, count has at most one argument, and every
//      argument is a field selection or a literal;
//   5. every aggregate function name is in the built-in allow-list or among the registered UDAF names;
//   6. validateAggRelFunctionType() accepts the relation;
//   7. the relation has at least one measure or at least one grouping expression.
// Returns true only when all checks pass.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel) {
  if (aggRel.has_input() && !validate(aggRel.input())) {
    LOG_VALIDATION_MSG("Input validation fails in AggregateRel.");
    return false;
  }

  // Validate input types.
  if (aggRel.has_advanced_extension()) {
    // Both variables are only out-parameters of the parsing helpers; their values are not used afterwards.
    TypePtr inputRowType;
    std::vector<TypePtr> types;
    const auto& extension = aggRel.advanced_extension();
    // Aggregate always has advanced extension for streaming aggregate optimization,
    // but only some of them have enhancement for validation.
    if (extension.has_enhancement() &&
        (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types))) {
      LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
      return false;
    }
  }

  // Validate groupings: only plain field selections are accepted as grouping keys.
  for (const auto& grouping : aggRel.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      if (groupingExpr.rex_type_case() != ::substrait::Expression::RexTypeCase::kSelection) {
        LOG_VALIDATION_MSG("Only field is supported in groupings.");
        return false;
      }
    }
  }

  // Validate aggregate functions. The function name of each measure is extracted once here and reused by the
  // allow-list check further down.
  std::vector<std::string> funcNames;
  funcNames.reserve(aggRel.measures().size());
  for (const auto& smea : aggRel.measures()) {
    // Validate the filter expression.
    if (smea.has_filter()) {
      // Bind by reference: the expression is only inspected, so copying the protobuf message is unnecessary.
      const auto& aggRelMask = smea.filter();
      if (aggRelMask.ByteSizeLong() > 0 &&
          aggRelMask.rex_type_case() != ::substrait::Expression::RexTypeCase::kSelection) {
        LOG_VALIDATION_MSG("Only field is supported in aggregate filter expression.");
        return false;
      }
    }

    const auto& aggFunction = smea.measure();
    const auto& functionSpec = planConverter_.findFuncSpec(aggFunction.function_reference());
    // The parsed type is discarded; presumably the call is kept so that an unsupported output type is
    // rejected by the parser itself - TODO confirm.
    SubstraitParser::parseType(aggFunction.output_type());
    funcNames.emplace_back(SubstraitParser::getNameBeforeDelimiter(functionSpec));

    // Validate the size of arguments: count accepts only one argument.
    if (funcNames.back() == "count" && aggFunction.arguments().size() > 1) {
      LOG_VALIDATION_MSG("Count should have only one argument.");
      return false;
    }

    // Arguments may be field selections or literals. NOTE(review): the message below mentions only fields
    // although literals are accepted too; it is kept unchanged because callers may match on it.
    for (const auto& arg : aggFunction.arguments()) {
      const auto typeCase = arg.value().rex_type_case();
      if (typeCase != ::substrait::Expression::RexTypeCase::kSelection &&
          typeCase != ::substrait::Expression::RexTypeCase::kLiteral) {
        LOG_VALIDATION_MSG("Only field is supported in aggregate functions.");
        return false;
      }
    }
  }

  // The supported aggregation functions. TODO: Remove this set when Presto aggregate functions in Velox are not
  // needed to be registered.
  static const std::unordered_set<std::string> supportedAggFuncs = {
      "sum",
      "collect_set",
      "collect_list",
      "count",
      "avg",
      "min",
      "max",
      "min_by",
      "max_by",
      "stddev_samp",
      "stddev_pop",
      "bloom_filter_agg",
      "var_samp",
      "var_pop",
      "bit_and",
      "bit_or",
      "bit_xor",
      "first",
      "first_ignore_null",
      "last",
      "last_ignore_null",
      "corr",
      "regr_r2",
      "covar_pop",
      "covar_samp",
      "approx_distinct",
      "skewness",
      "kurtosis",
      "regr_slope",
      "regr_intercept",
      "regr_sxy",
      "regr_replacement"};

  // A function is accepted when it is either built in (set above) or one of the registered UDAF names.
  const auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();
  for (const auto& funcName : funcNames) {
    if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) {
      LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel.");
      return false;
    }
  }

  if (!validateAggRelFunctionType(aggRel)) {
    return false;
  }

  // Validate both groupby and aggregates input are empty, which is corner case.
  if (aggRel.measures_size() == 0) {
    bool hasExpr = false;
    for (const auto& grouping : aggRel.groupings()) {
      if (grouping.grouping_expressions().size() > 0) {
        hasExpr = true;
        break;
      }
    }
    if (!hasExpr) {
      LOG_VALIDATION_MSG("Aggregation must specify either grouping keys or aggregates.");
      return false;
    }
  }

  return true;
}