in backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala [341:495]
/**
 * Builds an AggregateRel whose per-function intermediate buffers are first packed into single
 * struct columns by a leading ProjectRel (the "row construct" strategy).
 *
 * The method has two tightly coupled phases:
 *   1. A projection is emitted that lays out, in order: all grouping expressions, then one
 *      entry per aggregate expression (either the raw children for Partial mode, or a single
 *      row-construct struct for PartialMerge/Final mode).
 *   2. The aggregation rel is emitted on top, referencing the projection's output purely by
 *      ordinal selection (`colIdx`). The selection bookkeeping in phase 2 must therefore match
 *      the column layout produced in phase 1 exactly — do not reorder either phase.
 *
 * @param context                 Substrait context carrying the function registry.
 * @param originalInputAttributes output attributes of the child operator; expression
 *                                transformers resolve attribute references against these.
 * @param operatorId              id of this operator within the Substrait plan.
 * @param inputRel                the upstream rel the projection is built on.
 * @param validation              when true, input types are attached via an advanced-extension
 *                                node so the native side can validate the plan.
 * @return the assembled AggregateRel node.
 */
private def getAggRelWithRowConstruct(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
inputRel: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Phase 1: create a projection for row construct.
// Grouping expressions are projected first, so selections [0, groupingExpressions.size)
// in the aggregation phase below refer to grouping columns.
val exprNodes = new JArrayList[ExpressionNode]()
groupingExpressions.foreach(
expr => {
exprNodes.add(
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
.doTransform(args))
})
for (aggregateExpression <- aggregateExpressions) {
val aggFunc = aggregateExpression.aggregateFunction
val functionInputAttributes = aggFunc.inputAggBufferAttributes
aggFunc match {
case _
if aggregateExpression.mode == Partial => // FIXME: Any difference with the last branch?
// Partial mode: project the function's raw child expressions (one column each);
// no intermediate buffer exists yet, so no row construct is needed.
val childNodes = aggFunc.children
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.addAll(childNodes)
case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
throw new UnsupportedOperationException("Only one input attribute is expected.")
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
// The function's Velox intermediate layout differs from Spark's buffer layout.
// The process of handling the inconsistency in column types and order between
// Spark and Velox is exactly the opposite of applyExtractStruct.
aggregateExpression.mode match {
case PartialMerge | Final =>
val newInputAttributes = new ArrayBuffer[Attribute]()
val childNodes = new JArrayList[ExpressionNode]()
// Spark-side buffer field names/types, in Spark's order.
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
// adjustedOrders(i) = index into the Spark buffer of the i-th Velox field,
// i.e. a permutation mapping Velox order -> Spark order.
val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
veloxTypes.zipWithIndex.foreach {
case (veloxType, idx) =>
val sparkType = sparkTypes(adjustedOrders(idx))
val attr = functionInputAttributes(adjustedOrders(idx))
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.doTransform(args)
val expressionNode = if (sparkType != veloxType) {
// Type mismatch: cast the Spark buffer column to the Velox type and
// record the attribute with the adjusted type.
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
ExpressionBuilder.makeCast(
ConverterUtils.getTypeNode(veloxType, attr.nullable),
aggFuncInputAttrNode,
SQLConf.get.ansiEnabled)
} else {
newInputAttributes += attr
aggFuncInputAttrNode
}
childNodes.add(expressionNode)
}
// Pack the reordered/cast buffer columns into one struct column.
exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes, aggFunc))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _ =>
// No Velox-specific intermediate layout: project the Spark buffer
// attributes as-is (one column each).
val childNodes = functionInputAttributes
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args)
)
.asJava
exprNodes.addAll(childNodes)
}
}
// Create a project rel. emitStartIndex makes the projection emit only the newly
// constructed columns, not the pass-through input columns.
val emitStartIndex = originalInputAttributes.size
val projectRel = if (!validation) {
RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
inputRel,
exprNodes,
extensionNode,
context,
operatorId,
emitStartIndex)
}
// Phase 2: create the aggregation rel on top of the projection. colIdx walks the
// projection's output columns in the same order phase 1 appended them.
val groupingList = new JArrayList[ExpressionNode]()
var colIdx = 0
groupingExpressions.foreach {
_ =>
groupingList.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
}
val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
aggregateExpressions.foreach(
aggExpr => {
if (aggExpr.filter.isDefined) {
throw new UnsupportedOperationException("Filter in final aggregation is not supported.")
} else {
// The number of filters should be aligned with that of aggregate functions.
aggFilterList.add(null)
}
val aggFunc = aggExpr.aggregateFunction
val childrenNodes = new JArrayList[ExpressionNode]()
aggExpr.mode match {
case PartialMerge | Final =>
// Only occupies one column due to intermediate results are combined
// by previous projection.
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case Partial =>
// Partial mode projected one column per child above; consume them all.
aggFunc.children.foreach {
_ =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
}
case _ =>
throw new UnsupportedOperationException(
s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
}
addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
})
val extensionNode = getAdvancedExtension()
RelBuilder.makeAggregateRel(
projectRel,
groupingList,
aggregateFunctionList,
aggFilterList,
extensionNode,
context,
operatorId)
}