bool SubstraitToVeloxPlanValidator::validate()

in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [1043:1179]


// Validates a Substrait AggregateRel.
//
// The checks run in the following order, and the first failure logs a
// validation message and returns false:
//   1. The input relation (when present) must pass validation.
//   2. When the advanced extension carries an enhancement, the input types
//      must pass validateInputTypes().
//   3. Every grouping expression must be a field selection.
//   4. For every measure: a non-empty filter must be a field selection, the
//      function spec must be resolvable and the output type parsable (a
//      VeloxException raised here is logged and treated as a failure),
//      "count" must not take more than one argument, and every argument must
//      be a field selection or a literal.
//   5. Every aggregate function name must be in the supported set.
//   6. validateAggRelFunctionType() must accept the relation.
//   7. The relation must have at least one measure or one grouping expression.
//
// @param aggRel The aggregate relation to validate.
// @return true when all checks pass, false otherwise.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& aggRel) {
  if (aggRel.has_input() && !validate(aggRel.input())) {
    LOG_VALIDATION_MSG("Input validation fails in AggregateRel.");
    return false;
  }

  // Validate input types.
  if (aggRel.has_advanced_extension()) {
    std::vector<TypePtr> types;
    const auto& extension = aggRel.advanced_extension();
    // Aggregate always has advanced extension for streaming aggregate optimization,
    // but only some of them have enhancement for validation.
    if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
      LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
      return false;
    }
  }

  // Validate groupings: only plain field references are accepted as grouping keys.
  for (const auto& grouping : aggRel.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      if (groupingExpr.rex_type_case() != ::substrait::Expression::RexTypeCase::kSelection) {
        LOG_VALIDATION_MSG("Only field is supported in groupings.");
        return false;
      }
    }
  }

  // Validate aggregate functions. The function specs are collected here and
  // checked against the supported set after all measures passed the
  // per-measure checks.
  std::vector<std::string> funcSpecs;
  funcSpecs.reserve(aggRel.measures().size());
  for (const auto& smea : aggRel.measures()) {
    try {
      // Validate the filter expression.
      if (smea.has_filter()) {
        // Bind by reference: the expression is only inspected, so copying the
        // protobuf message is unnecessary.
        const auto& aggRelMask = smea.filter();
        // An empty (zero-byte) filter message is skipped, not rejected.
        if (aggRelMask.ByteSizeLong() > 0 &&
            aggRelMask.rex_type_case() != ::substrait::Expression::RexTypeCase::kSelection) {
          LOG_VALIDATION_MSG("Only field is supported in aggregate filter expression.");
          return false;
        }
      }

      const auto& aggFunction = smea.measure();
      const auto& functionSpec = planConverter_.findFuncSpec(aggFunction.function_reference());
      funcSpecs.emplace_back(functionSpec);
      // The parsed type is intentionally discarded; the call sits inside the
      // try block so that a VeloxException raised while parsing the output
      // type is reported as a validation failure.
      SubstraitParser::parseType(aggFunction.output_type());

      // Validate the size of arguments: count accepts only one argument.
      if (SubstraitParser::getNameBeforeDelimiter(functionSpec) == "count" && aggFunction.arguments().size() > 1) {
        LOG_VALIDATION_MSG("Count should have only one argument.");
        return false;
      }

      // Arguments must be either field references or literals.
      for (const auto& arg : aggFunction.arguments()) {
        const auto typeCase = arg.value().rex_type_case();
        switch (typeCase) {
          case ::substrait::Expression::RexTypeCase::kSelection:
          case ::substrait::Expression::RexTypeCase::kLiteral:
            break;
          default:
            LOG_VALIDATION_MSG("Only field or literal is supported in aggregate functions.");
            return false;
        }
      }
    } catch (const VeloxException& err) {
      LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
      return false;
    }
  }

  // The supported aggregation functions. TODO: Remove this set when Presto aggregate functions in Velox are not
  // needed to be registered.
  static const std::unordered_set<std::string> supportedAggFuncs = {
      "sum",
      "collect_set",
      "count",
      "avg",
      "min",
      "max",
      "min_by",
      "max_by",
      "stddev_samp",
      "stddev_pop",
      "bloom_filter_agg",
      "var_samp",
      "var_pop",
      "bit_and",
      "bit_or",
      "bit_xor",
      "first",
      "first_ignore_null",
      "last",
      "last_ignore_null",
      "corr",
      "covar_pop",
      "covar_samp",
      "approx_distinct"};

  // Reject the relation if any collected function is outside the supported set.
  for (const auto& funcSpec : funcSpecs) {
    const auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
    if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end()) {
      LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel.");
      return false;
    }
  }

  if (!validateAggRelFunctionType(aggRel)) {
    return false;
  }

  // Validate both groupby and aggregates input are empty, which is corner case.
  if (aggRel.measures_size() == 0) {
    bool hasGroupingExpr = false;
    for (const auto& grouping : aggRel.groupings()) {
      if (grouping.grouping_expressions().size() > 0) {
        hasGroupingExpr = true;
        break;
      }
    }

    if (!hasGroupingExpr) {
      LOG_VALIDATION_MSG("Aggregation must specify either grouping keys or aggregates.");
      return false;
    }
  }
  return true;
}