std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(const ::substrait::AggregateRel&)

in velox/substrait/SubstraitToVeloxPlan.cpp [22:171]


/// Converts a Substrait AggregateRel into a Velox AggregationNode.
///
/// The child relation is converted first. Grouping expressions must be field
/// references into the child's output. Each measure becomes one aggregate
/// call whose arguments are either field references or scalar functions.
/// When at least one argument is a scalar function, a ProjectNode is inserted
/// between the child and the aggregation. It forwards every child column
/// unchanged and appends one computed column per scalar-function argument.
/// The aggregate then reads that appended column.
///
/// The aggregation step is taken from the first measure's phase. A relation
/// with grouping keys but no measures is planned as a single-step aggregation.
///
/// @param sAgg The Substrait aggregate relation to convert.
/// @return The AggregationNode. Its source is the converted child or the
///         inserted pre-projection.
/// Fails via VELOX_FAIL when the relation has no input. Fails via VELOX_NYI
/// for unsupported argument kinds or aggregation phases.
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::AggregateRel& sAgg) {
  if (!sAgg.has_input()) {
    VELOX_FAIL("Child Rel is expected in AggregateRel.");
  }
  std::shared_ptr<const core::PlanNode> childNode = toVeloxPlan(sAgg.input());
  const auto& inputType = childNode->outputType();

  // Field references in this relation are resolved against the node numbered
  // planNodeId_ - 1 (presumably the child converted just above).
  const int inputPlanNodeId = planNodeId_ - 1;

  // Construct Velox grouping keys. Velox requires grouping keys to be plain
  // fields, so every grouping expression is expected to be a FieldReference.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
      veloxGroupingExprs;
  for (const auto& grouping : sAgg.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      veloxGroupingExprs.emplace_back(exprConverter_->toVeloxExpr(
          groupingExpr.selection(), inputPlanNodeId));
    }
  }

  // Computed columns for the optional pre-projection, one per
  // scalar-function argument across all measures.
  std::vector<std::shared_ptr<const core::ITypedExpr>> projectExprs;
  std::vector<std::string> projectOutNames;

  // Construct Velox aggregate expressions, one per measure.
  std::vector<std::shared_ptr<const core::CallTypedExpr>> aggExprs;
  aggExprs.reserve(sAgg.measures().size());
  for (const auto& sMea : sAgg.measures()) {
    const auto& aggFunction = sMea.measure();
    std::vector<std::shared_ptr<const core::ITypedExpr>> aggParams;
    aggParams.reserve(aggFunction.args().size());
    for (const auto& arg : aggFunction.args()) {
      const auto typeCase = arg.rex_type_case();
      switch (typeCase) {
        case ::substrait::Expression::RexTypeCase::kSelection:
          aggParams.emplace_back(
              exprConverter_->toVeloxExpr(arg.selection(), inputPlanNodeId));
          break;
        case ::substrait::Expression::RexTypeCase::kScalarFunction: {
          // The function is evaluated by the pre-projection, and the
          // aggregate reads its result column. The column is named by its
          // position in projectExprs, so no two computed columns share a
          // name, even when one measure has several scalar-function args.
          const auto& sFunc = arg.scalar_function();
          auto colOutName = subParser_->makeNodeName(
              planNodeId_, static_cast<int>(projectExprs.size()));
          projectExprs.emplace_back(
              exprConverter_->toVeloxExpr(sFunc, inputPlanNodeId));
          projectOutNames.emplace_back(colOutName);
          auto outType = subParser_->parseType(sFunc.output_type());
          aggParams.emplace_back(
              std::make_shared<const core::FieldAccessTypedExpr>(
                  toVeloxType(outType->type), colOutName));
          break;
        }
        default:
          VELOX_NYI(
              "Substrait conversion not supported for arg type '{}'", typeCase);
      }
    }
    auto funcId = aggFunction.function_reference();
    auto funcName = subParser_->findVeloxFunction(functionMap_, funcId);
    auto aggOutType = subParser_->parseType(aggFunction.output_type());
    aggExprs.emplace_back(std::make_shared<const core::CallTypedExpr>(
        toVeloxType(aggOutType->type), std::move(aggParams), funcName));
  }

  // Determine the aggregation step from the first measure. Without measures
  // (grouping keys only) there is no phase to read, so the aggregation runs
  // as a single step instead of leaving the step uninitialized.
  core::AggregationNode::Step aggStep = core::AggregationNode::Step::kSingle;
  if (!sAgg.measures().empty()) {
    const auto phase = sAgg.measures(0).measure().phase();
    switch (phase) {
      case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
        aggStep = core::AggregationNode::Step::kPartial;
        break;
      case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
        aggStep = core::AggregationNode::Step::kIntermediate;
        break;
      case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT:
        aggStep = core::AggregationNode::Step::kSingle;
        break;
      case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
        aggStep = core::AggregationNode::Step::kFinal;
        break;
      default:
        VELOX_NYI("Substrait conversion not supported for phase '{}'", phase);
    }
  }

  // The aggregation reads from the child directly, unless some argument
  // required a pre-projection.
  std::shared_ptr<const core::PlanNode> aggSource = childNode;
  if (!projectExprs.empty()) {
    // Grouping keys and field-reference arguments were resolved against the
    // child's columns. Forward all of them unchanged so they stay visible to
    // the aggregation, then append the computed columns.
    std::vector<std::string> names;
    std::vector<std::shared_ptr<const core::ITypedExpr>> exprs;
    names.reserve(inputType->size() + projectOutNames.size());
    exprs.reserve(inputType->size() + projectExprs.size());
    for (uint32_t i = 0; i < inputType->size(); ++i) {
      names.emplace_back(inputType->nameOf(i));
      exprs.emplace_back(std::make_shared<const core::FieldAccessTypedExpr>(
          inputType->childAt(i), inputType->nameOf(i)));
    }
    for (size_t i = 0; i < projectExprs.size(); ++i) {
      names.emplace_back(std::move(projectOutNames[i]));
      exprs.emplace_back(std::move(projectExprs[i]));
    }
    aggSource = std::make_shared<core::ProjectNode>(
        nextPlanNodeId(), std::move(names), std::move(exprs), childNode);
  }

  // AggregationNode takes exactly one output name per aggregate. Grouping-key
  // columns are not renamed by it. The aggregate names are numbered after the
  // grouping keys to match their position in the node's output. They are
  // built here, after any ProjectNode has taken its id, exactly as the
  // original code did for each branch.
  // NOTE(review): grouping-key outputs keep their input field names. A
  // downstream reference built as makeNodeName(<this node id>, idx) for an
  // idx below the number of grouping keys may therefore not resolve. Confirm
  // against the expression converter.
  const int numGroupingKeys = static_cast<int>(veloxGroupingExprs.size());
  std::vector<std::string> aggOutNames;
  aggOutNames.reserve(aggExprs.size());
  for (int idx = 0; idx < static_cast<int>(aggExprs.size()); ++idx) {
    aggOutNames.emplace_back(
        subParser_->makeNodeName(planNodeId_, numGroupingKeys + idx));
  }

  // One null mask per aggregate, meaning no aggregate is masked.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> aggregateMasks(
      aggExprs.size());
  // The input is not declared as pre-grouped on any key.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
      preGroupingExprs;
  const bool ignoreNullKeys = false;
  return std::make_shared<core::AggregationNode>(
      nextPlanNodeId(),
      aggStep,
      veloxGroupingExprs,
      preGroupingExprs,
      aggOutNames,
      aggExprs,
      aggregateMasks,
      ignoreNullKeys,
      aggSource);
}