bool SubstraitToVeloxPlanValidator::validate()

in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [1172:1315]


// Validates whether a Substrait AggregateRel can be converted into a Velox aggregation.
//
// Checks, in order:
//   1. the input relation (recursively);
//   2. the input types carried in the advanced extension, when an enhancement is present;
//   3. every grouping expression is a field selection;
//   4. every non-empty aggregate filter is a field selection;
//   5. every aggregate argument is a field selection or a literal, and `count` has at most
//      one argument;
//   6. every aggregate function is in the built-in allow-list or is a registered UDAF;
//   7. the per-function checks in validateAggRelFunctionType;
//   8. the aggregation has at least one grouping key or at least one aggregate.
//
// @param aggRel The Substrait aggregate relation to validate.
// @return true if every check passes. Otherwise false, after logging the reason via
//         LOG_VALIDATION_MSG.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel) {
  if (aggRel.has_input() && !validate(aggRel.input())) {
    LOG_VALIDATION_MSG("Input validation fails in AggregateRel.");
    return false;
  }

  // Validate input types.
  if (aggRel.has_advanced_extension()) {
    TypePtr inputRowType;
    std::vector<TypePtr> types;
    const auto& extension = aggRel.advanced_extension();
    // Aggregate always has advanced extension for streaming aggregate optimization,
    // but only some of them have enhancement for validation.
    if (extension.has_enhancement() &&
        (!parseVeloxType(extension, inputRowType) || !flattenSingleLevel(inputRowType, types))) {
      LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
      return false;
    }
  }

  // Validate groupings: only plain field references are accepted as grouping keys.
  for (const auto& grouping : aggRel.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      switch (groupingExpr.rex_type_case()) {
        case ::substrait::Expression::RexTypeCase::kSelection:
          break;
        default:
          LOG_VALIDATION_MSG("Only field is supported in groupings.");
          return false;
      }
    }
  }

  // Validate the expressions of each aggregate measure. The bare function names are
  // collected here and checked against the supported set only after every measure's
  // expressions have passed.
  std::vector<std::string> funcNames;
  funcNames.reserve(aggRel.measures().size());
  for (const auto& smea : aggRel.measures()) {
    // Validate the filter expression. Bind by reference: this is a read-only check,
    // so there is no reason to copy the protobuf message.
    if (smea.has_filter()) {
      const auto& aggRelMask = smea.filter();
      if (aggRelMask.ByteSizeLong() > 0) {
        switch (aggRelMask.rex_type_case()) {
          case ::substrait::Expression::RexTypeCase::kSelection:
            break;
          default:
            LOG_VALIDATION_MSG("Only field is supported in aggregate filter expression.");
            return false;
        }
      }
    }

    const auto& aggFunction = smea.measure();
    const auto& functionSpec = planConverter_.findFuncSpec(aggFunction.function_reference());
    // The parsed output type is discarded. Presumably the call is kept so that an
    // unparsable output type is rejected (e.g. by throwing). NOTE(review): confirm.
    SubstraitParser::parseType(aggFunction.output_type());
    // Extract the bare function name once; it is reused below and in the support check.
    funcNames.emplace_back(SubstraitParser::getNameBeforeDelimiter(functionSpec));

    // Validate the size of arguments: count accepts only one argument.
    if (funcNames.back() == "count" && aggFunction.arguments().size() > 1) {
      LOG_VALIDATION_MSG("Count should have only one argument.");
      return false;
    }

    // Arguments may be field references or literals.
    for (const auto& arg : aggFunction.arguments()) {
      switch (arg.value().rex_type_case()) {
        case ::substrait::Expression::RexTypeCase::kSelection:
        case ::substrait::Expression::RexTypeCase::kLiteral:
          break;
        default:
          LOG_VALIDATION_MSG("Only field or literal is supported in aggregate functions.");
          return false;
      }
    }
  }

  // The supported aggregation functions. TODO: Remove this set when Presto aggregate functions in Velox are not
  // needed to be registered.
  static const std::unordered_set<std::string> supportedAggFuncs = {
      "sum",
      "collect_set",
      "collect_list",
      "count",
      "avg",
      "min",
      "max",
      "min_by",
      "max_by",
      "stddev_samp",
      "stddev_pop",
      "bloom_filter_agg",
      "var_samp",
      "var_pop",
      "bit_and",
      "bit_or",
      "bit_xor",
      "first",
      "first_ignore_null",
      "last",
      "last_ignore_null",
      "corr",
      "regr_r2",
      "covar_pop",
      "covar_samp",
      "approx_distinct",
      "skewness",
      "kurtosis",
      "regr_slope",
      "regr_intercept",
      "regr_sxy",
      "regr_replacement"};

  // A function is accepted if it is either built-in (listed above) or a registered UDAF.
  auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();

  for (const auto& funcName : funcNames) {
    if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) {
      LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel.");
      return false;
    }
  }

  if (!validateAggRelFunctionType(aggRel)) {
    return false;
  }

  // Corner case: reject an aggregation that has neither grouping keys nor aggregates.
  if (aggRel.measures_size() == 0) {
    bool hasExpr = false;
    for (const auto& grouping : aggRel.groupings()) {
      if (grouping.grouping_expressions().size() > 0) {
        hasExpr = true;
        break;
      }
    }

    if (!hasExpr) {
      LOG_VALIDATION_MSG("Aggregation must specify either grouping keys or aggregates.");
      return false;
    }
  }
  return true;
}