in sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala [628:866]
/**
 * Generates the per-row Java code for a hash aggregate that has grouping keys.
 *
 * The generated code:
 *  1. computes the grouping key and, when the fast hash map is enabled, probes it first;
 *  2. otherwise (or on a fast-map miss) looks the key up in the regular hash map
 *     (`hashMapTerm`). If no buffer can be obtained, the map is spilled into `sorterTerm`
 *     and the lookup is retried once; if that also fails a `SparkOutOfMemoryError`
 *     with error class `AGGREGATE_OUT_OF_MEMORY` is thrown;
 *  3. evaluates the update (Partial/Complete) or merge (PartialMerge/Final) expressions of
 *     every aggregate function and writes the results into the buffer that was found.
 *
 * When `testFallbackStartsAt` is defined, a generated counter gates the regular-map lookup:
 * once it reaches the configured count the lookup is skipped so the spill path is taken,
 * and the counter is reset after each spill.
 *
 * Note: this method overwrites `ctx.currentVars` and `ctx.INPUT_ROW` and does not restore them.
 *
 * @param ctx   the codegen context.
 * @param input the codegen values of the input columns for the current row; they are placed
 *              after the aggregation buffer slots in `ctx.currentVars`.
 * @return the generated Java code.
 */
protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  // create grouping key
  val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
    ctx, bindReferences[Expression](groupingExpressions, child.output))
  // The fast hash map is probed with the individual key columns rather than an UnsafeRow.
  // Only generate that code when the fast hash map is actually used, so no unused
  // key-evaluation code (or any state it may register in `ctx`) is produced otherwise.
  val fastRowKeys = if (isFastHashMapEnabled) {
    ctx.generateExpressions(bindReferences[Expression](groupingExpressions, child.output))
  } else {
    Seq.empty
  }
  val unsafeRowKeys = unsafeRowKeyCode.value
  val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash")
  val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
  val fastRowBuffer = ctx.freshName("fastAggBuffer")

  // To individually generate code for each aggregate function, an element in `updateExprs` holds
  // all the expressions for the buffer of an aggregation function.
  val updateExprs = aggregateExpressions.map { e =>
    // only have DeclarativeAggregate
    e.mode match {
      case Partial | Complete =>
        e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
      case PartialMerge | Final =>
        e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
    }
  }

  // Test hook: when enabled, a counter decides whether the regular map lookup is attempted.
  val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
    case Some((_, regularMapCounter)) =>
      val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
      (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
    case _ => ("true", "", "")
  }

  val oomeClassName = classOf[SparkOutOfMemoryError].getName

  val findOrInsertRegularHashMap: String =
    s"""
       |// generate grouping key
       |${unsafeRowKeyCode.code}
       |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
       |if ($checkFallbackForBytesToBytesMap) {
       |  // try to get the buffer from hash map
       |  $unsafeRowBuffer =
       |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
       |}
       |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
       |// aggregation after processing all input rows.
       |if ($unsafeRowBuffer == null) {
       |  if ($sorterTerm == null) {
       |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
       |  } else {
       |    $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
       |  }
       |  $resetCounter
       |  // the hash map had be spilled, it should have enough memory now,
       |  // try to allocate buffer again.
       |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
       |    $unsafeRowKeys, $unsafeRowKeyHash);
       |  if ($unsafeRowBuffer == null) {
       |    // failed to allocate the first page
       |    throw new $oomeClassName("AGGREGATE_OUT_OF_MEMORY", new java.util.HashMap());
       |  }
       |}
     """.stripMargin

  val findOrInsertHashMap: String = {
    if (isFastHashMapEnabled) {
      // If fast hash map is on, we first generate code to probe and update the fast hash map.
      // If the probe is successful the corresponding fast row buffer will hold the mutable row.
      s"""
         |${fastRowKeys.map(_.code).mkString("\n")}
         |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
         |  $fastRowBuffer = $fastHashMapTerm.findOrInsert(
         |    ${fastRowKeys.map(_.value).mkString(", ")});
         |}
         |// Cannot find the key in fast hash map, try regular hash map.
         |if ($fastRowBuffer == null) {
         |  $findOrInsertRegularHashMap
         |}
       """.stripMargin
    } else {
      findOrInsertRegularHashMap
    }
  }

  val inputAttrs = aggregateBufferAttributes ++ inputAttributes
  // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
  // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
  // generating input columns, we use `currentVars`.
  ctx.currentVars = (new Array[ExprCode](aggregateBufferAttributes.length) ++ input)
    .toImmutableArraySeq

  val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
  // Start offset of each aggregation function's slots in the underlying buffer row
  // (a running sum of the number of buffer expressions per function).
  val bufferStartOffsets = updateExprs.map(_.length).scanLeft(0)(_ + _).init.toArray

  // Generates the code that evaluates the update/merge expressions of every aggregate function
  // against `rowBuffer` and writes each result into that function's slots of the buffer.
  // `rowBuffer` becomes `ctx.INPUT_ROW`, so buffer columns are read from it. Returns the code for
  // the common sub-expressions and the code that evaluates the aggregate functions, in that order.
  def genUpdateBufferCode(rowBuffer: String, isVectorized: Boolean, updateComment: String) = {
    ctx.INPUT_ROW = rowBuffer
    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
      bindReferences(updateExprsForOneFunc, inputAttrs)
    }
    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
    val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
      ctx.withSubExprEliminationExprs(subExprs.states) {
        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
      }
    }
    val aggCodeBlocks = updateExprs.indices.map { i =>
      val evalsForOneFunc = bufferEvals(i)
      val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
      val bufferOffset = bufferStartOffsets(i)
      // All the update code for aggregation buffers should be placed in the end
      // of each aggregation function code.
      val updateRowBuffer = evalsForOneFunc.zipWithIndex.map { case (ev, j) =>
        val updateExpr = boundUpdateExprsForOneFunc(j)
        CodeGenerator.updateColumn(rowBuffer, updateExpr.dataType, bufferOffset + j, ev,
          updateExpr.nullable, isVectorized = isVectorized)
      }
      code"""
         |${ctx.registerComment(s"evaluate aggregate function for ${aggNames(i)}")}
         |${evaluateVariables(evalsForOneFunc)}
         |${ctx.registerComment(updateComment)}
         |${updateRowBuffer.mkString("\n").trim}
       """.stripMargin
    }
    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
    (effectiveCodes, codeToEvalAggFuncs)
  }

  // Generated first so that the codegen context is used in the same order as before:
  // regular-map update code, then (if enabled) vectorized fast-map update code.
  val (regularSubExprCode, regularAggFuncsCode) = genUpdateBufferCode(
    unsafeRowBuffer, isVectorized = false, updateComment = "update unsafe row buffer")
  val updateRowInRegularHashMap: String =
    s"""
       |// common sub-expressions
       |$regularSubExprCode
       |// evaluate aggregate functions and update aggregation buffers
       |$regularAggFuncsCode
     """.stripMargin

  val updateRowInHashMap: String = {
    if (isFastHashMapEnabled) {
      if (isVectorizedHashMapEnabled) {
        val (fastSubExprCode, fastAggFuncsCode) = genUpdateBufferCode(
          fastRowBuffer, isVectorized = true, updateComment = "update fast row")
        // If vectorized fast hash map is on, we first generate code to update row
        // in vectorized fast hash map, if the previous lookup hit vectorized fast hash map.
        // Otherwise, update row in regular hash map.
        s"""
           |if ($fastRowBuffer != null) {
           |  // common sub-expressions
           |  $fastSubExprCode
           |  // evaluate aggregate functions and update aggregation buffers
           |  $fastAggFuncsCode
           |} else {
           |  $updateRowInRegularHashMap
           |}
         """.stripMargin
      } else {
        // If row-based hash map is on and the previous lookup hit fast hash map,
        // we reuse regular hash buffer to update row of fast hash map.
        // Otherwise, update row in regular hash map.
        s"""
           |// Updates the proper row buffer
           |if ($fastRowBuffer != null) {
           |  $unsafeRowBuffer = $fastRowBuffer;
           |}
           |$updateRowInRegularHashMap
         """.stripMargin
      }
    } else {
      updateRowInRegularHashMap
    }
  }

  val declareRowBuffer: String = if (isFastHashMapEnabled) {
    val fastRowType = if (isVectorizedHashMapEnabled) {
      classOf[MutableColumnarRow].getName
    } else {
      "UnsafeRow"
    }
    s"""
       |UnsafeRow $unsafeRowBuffer = null;
       |$fastRowType $fastRowBuffer = null;
     """.stripMargin
  } else {
    s"UnsafeRow $unsafeRowBuffer = null;"
  }

  // We try to do hash map based in-memory aggregation first. If there is not enough memory (the
  // hash map will return null for new key), we spill the hash map to disk to free memory, then
  // continue to do in-memory aggregation and spilling until all the rows had been processed.
  // Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
  s"""
     |$declareRowBuffer
     |$findOrInsertHashMap
     |$incCounter
     |$updateRowInHashMap
   """.stripMargin
}