in velox/substrait/SubstraitToVeloxPlan.cpp [22:171]
/// Converts a Substrait AggregateRel into a Velox AggregationNode. When a
/// measure argument is a scalar function, a ProjectNode evaluating that
/// function is inserted between the converted input and the aggregation.
///
/// @param sAgg The Substrait aggregate relation. It must carry an input Rel.
/// @return The AggregationNode built on top of the converted input, or on top
/// of the pre-projection when one was required.
///
/// Raises through VELOX_FAIL when the input Rel is missing, and through
/// VELOX_NYI for unsupported measure argument kinds or aggregation phases.
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::AggregateRel& sAgg) {
  if (!sAgg.has_input()) {
    VELOX_FAIL("Child Rel is expected in AggregateRel.");
  }
  std::shared_ptr<const core::PlanNode> childNode = toVeloxPlan(sAgg.input());

  // Must be read after the input was converted: field references below are
  // named after planNodeId_ - 1.
  const int inputPlanNodeId = planNodeId_ - 1;

  // Construct Velox grouping expressions.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
      veloxGroupingExprs;
  // The index of the output column: grouping keys first, then measures.
  int outIdx = 0;
  for (const auto& grouping : sAgg.groupings()) {
    for (const auto& groupingExpr : grouping.grouping_expressions()) {
      // Velox's groupings are limited to be Field, so groupingExpr is
      // expected to be FieldReference.
      veloxGroupingExprs.emplace_back(exprConverter_->toVeloxExpr(
          groupingExpr.selection(), inputPlanNodeId));
      ++outIdx;
    }
  }
  const int groupingCount = outIdx;

  // Parse measures to get the Aggregation phase and expressions.
  // The step is taken from the first measure only. It defaults to kSingle so
  // that a grouping-only AggregateRel (no measures) never reads an
  // uninitialized step.
  core::AggregationNode::Step aggStep = core::AggregationNode::Step::kSingle;
  bool phaseInited = false;

  // Project expressions are used to conduct a pre-projection before
  // Aggregation if needed.
  std::vector<std::shared_ptr<const core::ITypedExpr>> projectExprs;
  std::vector<std::string> projectOutNames;

  std::vector<std::shared_ptr<const core::CallTypedExpr>> aggExprs;
  aggExprs.reserve(sAgg.measures().size());

  // Construct Velox Aggregate expressions.
  for (const auto& sMea : sAgg.measures()) {
    // Bind by reference: copying protobuf messages is a deep copy.
    const auto& aggFunction = sMea.measure();
    const auto& args = aggFunction.args();

    // Get the params of this Aggregate function.
    std::vector<std::shared_ptr<const core::ITypedExpr>> aggParams;
    aggParams.reserve(args.size());
    for (const auto& arg : args) {
      const auto typeCase = arg.rex_type_case();
      switch (typeCase) {
        case ::substrait::Expression::RexTypeCase::kSelection: {
          aggParams.emplace_back(
              exprConverter_->toVeloxExpr(arg.selection(), inputPlanNodeId));
          break;
        }
        case ::substrait::Expression::RexTypeCase::kScalarFunction: {
          // Pre-projection is needed before Aggregate.
          // The input of Aggregation will be the output of the
          // pre-projection.
          const auto& sFunc = arg.scalar_function();
          projectExprs.emplace_back(
              exprConverter_->toVeloxExpr(sFunc, inputPlanNodeId));
          // planNodeId_ has not been advanced yet, so the projected column is
          // named after the id the ProjectNode receives below.
          auto colOutName = subParser_->makeNodeName(planNodeId_, outIdx);
          projectOutNames.emplace_back(colOutName);
          auto outType = subParser_->parseType(sFunc.output_type());
          aggParams.emplace_back(
              std::make_shared<const core::FieldAccessTypedExpr>(
                  toVeloxType(outType->type), colOutName));
          break;
        }
        default:
          VELOX_NYI(
              "Substrait conversion not supported for arg type '{}'", typeCase);
      }
    }

    auto funcId = aggFunction.function_reference();
    auto funcName = subParser_->findVeloxFunction(functionMap_, funcId);
    auto aggOutType = subParser_->parseType(aggFunction.output_type());
    aggExprs.emplace_back(std::make_shared<const core::CallTypedExpr>(
        toVeloxType(aggOutType->type), std::move(aggParams), funcName));

    // Initialize the Aggregate Step.
    if (!phaseInited) {
      const auto phase = aggFunction.phase();
      switch (phase) {
        case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE:
          aggStep = core::AggregationNode::Step::kPartial;
          break;
        case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE:
          aggStep = core::AggregationNode::Step::kIntermediate;
          break;
        case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT:
          aggStep = core::AggregationNode::Step::kFinal;
          break;
        default:
          VELOX_NYI("Substrait conversion not supported for phase '{}'", phase);
      }
      phaseInited = true;
    }
    ++outIdx;
  }

  // A Project Node is needed before Aggregation only when some measure
  // argument was a scalar function; otherwise aggregate the input directly.
  // NOTE(review): the pre-projection emits only the scalar-function results,
  // while grouping keys and plain field arguments are still named after the
  // input node. Confirm those columns resolve against the ProjectNode output.
  std::shared_ptr<const core::PlanNode> aggSource = childNode;
  if (!projectOutNames.empty()) {
    aggSource = std::make_shared<core::ProjectNode>(
        nextPlanNodeId(),
        std::move(projectOutNames),
        std::move(projectExprs),
        childNode);
  }

  // Construct the Aggregate Node.
  // Names and masks describe the aggregates only (grouping keys are passed
  // separately), so both hold exactly one entry per measure. Names are
  // numbered by output column position, i.e. after the grouping keys, and
  // use planNodeId_ before the AggregationNode's own id is taken below.
  std::vector<std::string> aggOutNames;
  aggOutNames.reserve(aggExprs.size());
  for (int idx = groupingCount; idx < outIdx; ++idx) {
    aggOutNames.emplace_back(subParser_->makeNodeName(planNodeId_, idx));
  }
  // Default-constructed (null) entries: no mask is set for any aggregate.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> aggregateMasks(
      aggExprs.size());
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>>
      preGroupingExprs;
  const bool ignoreNullKeys = false;

  return std::make_shared<core::AggregationNode>(
      nextPlanNodeId(),
      aggStep,
      veloxGroupingExprs,
      preGroupingExprs,
      aggOutNames,
      aggExprs,
      aggregateMasks,
      ignoreNullKeys,
      aggSource);
}