in backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala [321:481]
private def getAggRelWithRowConstruct(
    context: SubstraitContext,
    originalInputAttributes: Seq[Attribute],
    operatorId: Long,
    inputRel: RelNode,
    validation: Boolean): RelNode = {
  /*
   * Builds a Substrait project rel on top of `inputRel`, followed by an aggregate rel that
   * reads from that projection.
   *
   * The projection emits:
   *   - the grouping expressions, then
   *   - for every aggregate expression, in declaration order, its input column(s).
   *
   * The aggregate rel references the projected columns purely by position (makeSelection).
   * Therefore both halves of this method must agree on how many columns each aggregate
   * expression contributes:
   *   - Partial / Complete: one column per child of the aggregate function.
   *   - PartialMerge / Final: the columns produced by the projection branches below
   *     (a single row-construct column for Velox intermediate-data functions).
   *
   * Throws GlutenNotSupportException for aggregate filters, for a HyperLogLogPlusPlus
   * whose buffer is not exactly one attribute, and for unsupported modes.
   */
  // Create a projection for row construct.
  val exprNodes = new JArrayList[ExpressionNode]()
  groupingExpressions.foreach(
    expr => {
      exprNodes.add(
        ExpressionConverter
          .replaceWithExpressionTransformer(expr, originalInputAttributes)
          .doTransform(context))
    })
  for (aggregateExpression <- aggregateExpressions) {
    val aggFunc = aggregateExpression.aggregateFunction
    val mode = aggregateExpression.mode
    val functionInputAttributes = aggFunc.inputAggBufferAttributes
    aggFunc match {
      // Partial and Complete both consume raw input rows, so project the function's
      // children (one column each). This must match the `Partial | Complete` branch of the
      // aggregate rel below. Routing Complete through the buffer-attribute branches would
      // misalign every later positional selection.
      case _ if mode == Partial || mode == Complete =>
        val childNodes = aggFunc.children
          .map(
            ExpressionConverter
              .replaceWithExpressionTransformer(_, originalInputAttributes)
              .doTransform(context)
          )
          .asJava
        exprNodes.addAll(childNodes)
      case _: HyperLogLogPlusPlus if aggFunc.aggBufferAttributes.size != 1 =>
        throw new GlutenNotSupportException("Only one input attribute is expected.")
      case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
        val rewrittenInputAttributes =
          rewriteAggBufferAttributes(functionInputAttributes, originalInputAttributes)
        // The process of handling the inconsistency in column types and order between
        // Spark and Velox is exactly the opposite of applyExtractStruct.
        mode match {
          case PartialMerge | Final =>
            val newInputAttributes = new ArrayBuffer[Attribute]()
            val childNodes = new JArrayList[ExpressionNode]()
            val (sparkOrders, sparkTypes) =
              aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
            val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
            // For each Velox intermediate column, the index of the Spark buffer column with
            // the same name, or -1 when Spark has no such column.
            val adjustedOrders = veloxOrders.map(o => sparkOrders.indexOf(o.head))
            veloxTypes.zipWithIndex.foreach {
              case (veloxType, idx) =>
                val adjustedIdx = adjustedOrders(idx)
                if (adjustedIdx == -1) {
                  // The Velox aggregate intermediate buffer column not found in Spark.
                  // For example, skewness and kurtosis share the same aggregate buffer in
                  // Velox, and Kurtosis additionally requires the buffer column of m4, which
                  // is always 0 for skewness. In Spark, the aggregate buffer of skewness does
                  // not have the column of m4, thus a placeholder m4 with a value of 0 must be
                  // passed to Velox, and this value cannot be omitted. Velox will always read
                  // m4 column when accessing the intermediate data.
                  val extraAttr = AttributeReference(veloxOrders(idx).head, veloxType)()
                  newInputAttributes += extraAttr
                  val lt = Literal.default(veloxType)
                  childNodes.add(ExpressionBuilder.makeLiteral(lt.value, lt.dataType, false))
                } else {
                  val sparkType = sparkTypes(adjustedIdx)
                  val attr = rewrittenInputAttributes(adjustedIdx)
                  val aggFuncInputAttrNode = ExpressionConverter
                    .replaceWithExpressionTransformer(attr, originalInputAttributes)
                    .doTransform(context)
                  // Cast the Spark buffer column when its type differs from what Velox
                  // expects for the same intermediate column.
                  val expressionNode = if (sparkType != veloxType) {
                    newInputAttributes +=
                      attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier)
                    ExpressionBuilder.makeCast(
                      ConverterUtils.getTypeNode(veloxType, attr.nullable),
                      aggFuncInputAttrNode,
                      SQLConf.get.ansiEnabled)
                  } else {
                    newInputAttributes += attr
                    aggFuncInputAttrNode
                  }
                  childNodes.add(expressionNode)
                }
            }
            // All intermediate columns are packed into a single projected column, which is
            // why merge-mode aggregates read exactly one column in the aggregate rel below.
            exprNodes.add(
              getRowConstructNode(context, childNodes, newInputAttributes.toSeq, aggFunc))
          case other =>
            throw new GlutenNotSupportException(s"$other is not supported.")
        }
      case _ =>
        // NOTE(review): the aggregate rel below reads one column per PartialMerge/Final
        // aggregate, so this branch assumes such functions have a single buffer attribute.
        // TODO: confirm against the caller's row-construct eligibility check.
        val rewrittenInputAttributes =
          rewriteAggBufferAttributes(functionInputAttributes, originalInputAttributes)
        val childNodes = rewrittenInputAttributes
          .map(
            ExpressionConverter
              .replaceWithExpressionTransformer(_, originalInputAttributes)
              .doTransform(context)
          )
          .asJava
        exprNodes.addAll(childNodes)
    }
  }
  // Create a project rel.
  val projectRel = RelBuilder.makeProjectRel(
    originalInputAttributes.asJava,
    inputRel,
    exprNodes,
    context,
    operatorId,
    validation)
  // Create aggregation rel. `colIdx` walks the projected columns in the same order in which
  // they were emitted above.
  val groupingList = new JArrayList[ExpressionNode]()
  var colIdx = 0
  groupingExpressions.foreach {
    _ =>
      groupingList.add(ExpressionBuilder.makeSelection(colIdx))
      colIdx += 1
  }
  val aggFilterList = new JArrayList[ExpressionNode]()
  val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
  aggregateExpressions.foreach(
    aggExpr => {
      if (aggExpr.filter.isDefined) {
        throw new GlutenNotSupportException("Filter in final aggregation is not supported.")
      } else {
        // The number of filters should be aligned with that of aggregate functions.
        aggFilterList.add(null)
      }
      val aggFunc = aggExpr.aggregateFunction
      val childrenNodes = new JArrayList[ExpressionNode]()
      aggExpr.mode match {
        case PartialMerge | Final =>
          // Only occupies one column due to intermediate results are combined
          // by previous projection.
          childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
          colIdx += 1
        case Partial | Complete =>
          aggFunc.children.foreach {
            _ =>
              childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
              colIdx += 1
          }
        case _ =>
          throw new GlutenNotSupportException(
            s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
      }
      addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
    })
  val extensionNode = getAdvancedExtension()
  RelBuilder.makeAggregateRel(
    projectRel,
    groupingList,
    aggregateFunctionList,
    aggFilterList,
    extensionNode,
    context,
    operatorId)
}