bool SubstraitToVeloxPlanValidator::validate()

in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [580:715]


/// Validates a Substrait WindowRel for conversion to a Velox window node.
///
/// Checks, in order:
///   1. the input relation, when present;
///   2. that the advanced extension is present and carries valid input types;
///   3. every window measure: its function spec can be found, its output type
///      parses, its arguments are only field selections or literals, its frame
///      type is ROWS or RANGE, and both of its bounds pass validateBoundType();
///   4. that no measure uses a function listed in `unsupportedFuncs`;
///   5. that every partition expression is a plain field access and compiles;
///   6. that every sort field has a supported direction and, when it has an
///      expression, that expression is a plain field access and compiles.
///
/// @param windowRel the Substrait window relation to check.
/// @return true if all checks pass; false on the first failing check, after
///         logging the reason via LOG_VALIDATION_MSG or
///         LOG_VALIDATION_MSG_FROM_EXCEPTION.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windowRel) {
  if (windowRel.has_input() && !validate(windowRel.input())) {
    LOG_VALIDATION_MSG("WindowRel input fails to validate.");
    return false;
  }

  // Get and validate the input types from extension.
  if (!windowRel.has_advanced_extension()) {
    LOG_VALIDATION_MSG("Input types are expected in WindowRel.");
    return false;
  }
  const auto& extension = windowRel.advanced_extension();
  std::vector<TypePtr> types;
  if (!validateInputTypes(extension, types)) {
    LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
    return false;
  }

  // Build the input row type with synthesized column names so that the
  // partition and sort expressions below can be resolved against it.
  const int32_t inputPlanNodeId = 0;
  // Use an explicitly typed index (and one explicit cast of the size) instead of
  // comparing an `int` against `size_t`.
  const auto numColumns = static_cast<int32_t>(types.size());
  std::vector<std::string> names;
  names.reserve(types.size());
  for (int32_t colIdx = 0; colIdx < numColumns; ++colIdx) {
    names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
  }
  auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

  // Validate WindowFunction
  std::vector<std::string> funcSpecs;
  funcSpecs.reserve(windowRel.measures().size());
  for (const auto& smea : windowRel.measures()) {
    try {
      const auto& windowFunction = smea.measure();
      funcSpecs.emplace_back(planConverter_.findFuncSpec(windowFunction.function_reference()));
      // Parsed only for validation; the result is not needed here.
      SubstraitParser::parseType(windowFunction.output_type());
      // Only field selections and literals are accepted as arguments.
      for (const auto& arg : windowFunction.arguments()) {
        auto typeCase = arg.value().rex_type_case();
        switch (typeCase) {
          case ::substrait::Expression::RexTypeCase::kSelection:
          case ::substrait::Expression::RexTypeCase::kLiteral:
            break;
          default:
            LOG_VALIDATION_MSG("Only field or literal is supported in window functions.");
            return false;
        }
      }
      // Validate BoundType and Frame Type
      switch (windowFunction.window_type()) {
        case ::substrait::WindowType::ROWS:
        case ::substrait::WindowType::RANGE:
          break;
        default:
          LOG_VALIDATION_MSG(
              "the window type only support ROWS and RANGE, and the input type is " +
              std::to_string(windowFunction.window_type()));
          return false;
      }

      bool boundTypeSupported =
          validateBoundType(windowFunction.upper_bound()) && validateBoundType(windowFunction.lower_bound());
      if (!boundTypeSupported) {
        LOG_VALIDATION_MSG(
            "Found unsupported Bound Type: upper " + std::to_string(windowFunction.upper_bound().kind_case()) +
            ", lower " + std::to_string(windowFunction.lower_bound().kind_case()));
        return false;
      }
    } catch (const VeloxException& err) {
      LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
      return false;
    }
  }

  // Reject window measures whose function is known to be unsupported here.
  static const std::unordered_set<std::string> unsupportedFuncs = {"collect_list", "collect_set"};
  for (const auto& funcSpec : funcSpecs) {
    auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
    if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) {
      LOG_VALIDATION_MSG(funcName + " was not supported in WindowRel.");
      return false;
    }
  }

  // Validate partition-by expressions: each must be a plain field access.
  const auto& groupByExprs = windowRel.partition_expressions();
  std::vector<core::TypedExprPtr> expressions;
  expressions.reserve(groupByExprs.size());
  try {
    for (const auto& expr : groupByExprs) {
      auto expression = exprConverter_->toVeloxExpr(expr, rowType);
      auto fieldExpr = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
      if (fieldExpr == nullptr) {
        LOG_VALIDATION_MSG("Only field is supported for partition key in Window Operator!");
        return false;
      } else {
        expressions.emplace_back(expression);
      }
    }
    // Try to compile the expressions. If there is any unregistered function or
    // mismatched type, an exception will be thrown.
    exec::ExprSet exprSet(std::move(expressions), execCtx_);
  } catch (const VeloxException& err) {
    LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
    return false;
  }

  // Validate sort fields: a supported direction and, if present, a field-access key.
  const auto& sorts = windowRel.sorts();
  for (const auto& sort : sorts) {
    switch (sort.direction()) {
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
      case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
        break;
      default:
        LOG_VALIDATION_MSG("in windowRel, unsupported Sort direction " + std::to_string(sort.direction()));
        return false;
    }

    if (sort.has_expr()) {
      try {
        auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType);
        auto fieldExpr = dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
        if (!fieldExpr) {
          LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field.");
          return false;
        }
        // Compile the sort key to surface type/function errors as exceptions.
        exec::ExprSet exprSet({std::move(expression)}, execCtx_);
      } catch (const VeloxException& err) {
        LOG_VALIDATION_MSG_FROM_EXCEPTION(err);
        return false;
      }
    }
  }

  return true;
}