in sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala [429:626]
protected override def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
if (conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap()
} else if (conf.enableVectorizedHashMap) {
logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
}
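// `bitMaxCapacity` bounds the size of the fast hash map: once the map holds roughly
// 2^bitMaxCapacity keys, lookups fall back to the regular UnsafeFixedWidthAggregationMap.
// E.g. a test fallback counter of 5 gives floor(log2(5)) = 2, i.e. a cap of about 4 keys.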
val bitMaxCapacity = testFallbackStartsAt match {
case Some((fastMapCounter, _)) =>
// In testing, with the fallback counter of the fast hash map (`fastMapCounter`), set the max
// bit of the map to no more than log2(`fastMapCounter`). This helps control the number of
// keys in the map to mimic fallback.
if (fastMapCounter <= 1) {
0
} else {
(math.log10(fastMapCounter) / math.log10(2)).floor.toInt
}
case _ => conf.fastHashAggregateRowMaxCapacityBit
}
val thisPlan = ctx.addReferenceObj("plan", this)
// Create a name for the iterator from the fast hash map, and the code to create the fast hash map.
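// Note: the row-based map exposes a KVIterator over UnsafeRows, while the vectorized map
// exposes a java.util.Iterator over InternalRows that must be converted to UnsafeRows when
// producing output (see outputFromVectorizedMap below).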
val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) {
// Generates the fast hash map class and creates the fast hash map term.
val fastHashMapClassName = ctx.freshName("FastHashMap")
if (isVectorizedHashMapEnabled) {
val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)
// Inline mutable state since there are not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"java.util.Iterator<InternalRow>",
"vectorizedFastHashMapIter",
forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
(iter, create)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate()
ctx.addInnerClass(generatedMap)
// Inline mutable state since there are not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(
fastHashMapClassName, "fastHashMap", forceInline = true)
val iter = ctx.addMutableState(
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter", forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
s"$thisPlan.getTaskContext().taskMemoryManager(), " +
s"$thisPlan.getEmptyAggregationBuffer());"
(iter, create)
}
} else ("", "")
// Generates the code to register a cleanup task with TaskContext to ensure that memory
// is freed at the end of the task. This is necessary to avoid memory leaks when the
// downstream operator does not fully consume the aggregation map's output (e.g. an
// aggregate followed by a limit).
val addHookToCloseFastHashMap = if (isFastHashMapEnabled) {
s"""
|$thisPlan.getTaskContext().addTaskCompletionListener(
| new org.apache.spark.util.TaskCompletionListener() {
| @Override
| public void onTaskCompletion(org.apache.spark.TaskContext context) {
| $fastHashMapTerm.close();
| }
|});
""".stripMargin
} else ""
// Create a name for the iterator from the regular hash map.
// Inline mutable state since there are not many aggregation operations in a task
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
"mapIter", forceInline = true)
// Create the regular hash map.
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap", forceInline = true)
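// The external sorter is only assigned in the generated consume path if the regular hash map
// runs out of memory and the aggregation falls back to sort-based aggregation; otherwise it
// stays null and the map is freed directly after iteration (see outputFromRegularHashMap).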
sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
forceInline = true)
val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
val numTasksFallBacked = metricTerm(ctx, "numTasksFallBacked")
val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe, $numTasksFallBacked);"
val finishHashMap = if (isFastHashMapEnabled) {
s"""
|$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
|$finishRegularHashMap
""".stripMargin
} else {
finishRegularHashMap
}
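// `doAggregateWithKeys` drives the child's produce/consume loop to feed input rows into the
// hash map(s); once the input is exhausted, the maps are switched to iterator-based output.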
val doAggFuncName = ctx.addNewFunction(doAgg,
s"""
|private void $doAgg() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| $finishHashMap
|}
""".stripMargin)
// Generate the code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
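// generateResultFunction() emits a helper that turns a grouping key and its aggregation
// buffer into the operator's output row(s) and hands them to the parent's consume path.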
val outputFunc = generateResultFunction(ctx)
val limitNotReachedCondition = limitNotReachedCond
def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
outputFromVectorizedMap
} else {
outputFromRowBasedMap
}
} else ""
}
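// Output from the row-based fast hash map: keys and buffers are already UnsafeRows, so they
// can be passed to the result function directly.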
def outputFromRowBasedMap: String = {
s"""
|while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
|
| if (shouldStop()) return;
|}
|$fastHashMapTerm.close();
""".stripMargin
}
// Iterate over the aggregate rows and convert them from InternalRow to UnsafeRow
def outputFromVectorizedMap: String = {
val row = ctx.freshName("fastHashMapRow")
ctx.currentVars = null
ctx.INPUT_ROW = row
val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
toAttributes(groupingKeySchema).zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
)
val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
toAttributes(bufferSchema).zipWithIndex.map { case (attr, i) =>
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
|while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
| $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
|
| if (shouldStop()) return;
|}
|
|$fastHashMapTerm.close();
""".stripMargin
}
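// Output from the regular hash map (or from the sorter-backed iterator returned by
// finishAggregate() if the aggregation spilled).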
def outputFromRegularHashMap: String = {
s"""
|while ($limitNotReachedCondition $iterTerm.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
| if (shouldStop()) return;
|}
|$iterTerm.close();
|if ($sorterTerm == null) {
| $hashMapTerm.free();
|}
""".stripMargin
}
val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
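// The fast hash map (if any) is drained before the regular hash map, so keys that stayed in
// the fast map are emitted first, followed by keys that went through the regular map.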
s"""
|if (!$initAgg) {
| $initAgg = true;
| $createFastHashMap
| $addHookToCloseFastHashMap
| $hashMapTerm = $thisPlan.createHashMap();
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
| $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
|}
|// output the result
|$outputFromFastHashMap
|$outputFromRegularHashMap
""".stripMargin
}