in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala [45:352]
/**
 * Generates a batch hash aggregate operator for aggregations that have group keys.
 *
 * For every input row the generated `processElement` code:
 *   - projects the group key from the input row into a reused binary row;
 *   - looks the key up in the aggregate map. For an unseen key it appends an initial aggregate
 *     buffer. That buffer is (re)initialized per key when `auxGrouping` is non-empty. If the
 *     append throws `java.io.EOFException`, the generated OOM-handling code runs;
 *   - updates the key's aggregate buffer with the row.
 *
 * At end of input, a final aggregate emits the aggregate map directly when no sorter was
 * created. Otherwise it spills the remaining map contents to the sorter and falls back to
 * sort-based aggregation. A local (non-final) aggregate emits the map unless adaptive local
 * aggregation has been suppressed.
 *
 * Adaptive local hash aggregation is generated only for non-final aggregates, and only when
 * both `supportAdaptiveLocalHashAgg` and the
 * `ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED` option are set. When the
 * number of processed rows equals the configured sampling threshold, the ratio of distinct
 * keys to rows is compared with the configured distinct-value-rate threshold. If the ratio is
 * higher, the map is flushed once, and every later row is forwarded as a projected key/value
 * pair without being aggregated.
 *
 * @param ctx code generator context that receives the generated members, logger and locals
 * @param builder relational builder forwarded to the aggregate code helpers
 * @param aggInfoList the aggregate calls to generate code for
 * @param inputType row type of the operator input
 * @param outputType row type of the operator output
 * @param grouping input field indices forming the group key
 * @param auxGrouping input field indices of auxiliary grouping fields; they are part of the
 *                    aggregate buffer layout
 * @param isMerge forwarded to the aggregate code helpers; presumably `true` when the input
 *                already holds partial aggregate buffers (NOTE(review): confirm)
 * @param isFinal whether this is the final (global) aggregate rather than a local one
 * @param supportAdaptiveLocalHashAgg whether the local aggregate may be adaptively suppressed
 * @param maxNumFileHandles forwarded to the OOM handling / sort fallback code generation
 * @param compressionEnabled forwarded to the OOM handling / sort fallback code generation
 * @param compressionBlockSize forwarded to the OOM handling / sort fallback code generation
 * @return the generated operator
 */
def genWithKeys(
    ctx: CodeGeneratorContext,
    builder: RelBuilder,
    aggInfoList: AggregateInfoList,
    inputType: RowType,
    outputType: RowType,
    grouping: Array[Int],
    auxGrouping: Array[Int],
    isMerge: Boolean,
    isFinal: Boolean,
    supportAdaptiveLocalHashAgg: Boolean,
    maxNumFileHandles: Int,
    compressionEnabled: Boolean,
    compressionBlockSize: Int): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
  val aggInfos = aggInfoList.aggInfos
  val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
  val aggBufferPrefix = "hash"
  val aggBufferNames = AggCodeGenHelper.getAggBufferNames(aggBufferPrefix, auxGrouping, aggInfos)
  val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
  val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
  val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
  val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
  val className = if (isFinal) "HashAggregateWithKeys" else "LocalHashAggregateWithKeys"

  // Logger shared by the OOM/fallback handling and the adaptive sampling code.
  val logTerm = CodeGenUtils.newName(ctx, "LOG")
  ctx.addReusableLogger(logTerm, className)

  // gen code to do group key projection from input
  val currentKeyTerm = CodeGenUtils.newName(ctx, "currentKey")
  val currentKeyWriterTerm = CodeGenUtils.newName(ctx, "currentKeyWriter")
  // currentValueTerm and currentValueWriterTerm are only used for the value projection of
  // adaptive local hash aggregation.
  val currentValueTerm = CodeGenUtils.newName(ctx, "currentValue")
  val currentValueWriterTerm = CodeGenUtils.newName(ctx, "currentValueWriter")
  val keyProjectionCode = ProjectionCodeGenerator
    .generateProjectionExpression(
      ctx,
      inputType,
      groupKeyRowType,
      grouping,
      inputTerm = inputTerm,
      outRecordTerm = currentKeyTerm,
      outRecordWriterTerm = currentKeyWriterTerm)
    .code

  // gen code to create groupKey, aggBuffer Type array
  // it will be used in BytesHashMap and BufferedKVExternalSorter if enable fallback
  val groupKeyTypesTerm = CodeGenUtils.newName(ctx, "groupKeyTypes")
  val aggBufferTypesTerm = CodeGenUtils.newName(ctx, "aggBufferTypes")
  HashAggCodeGenHelper.prepareHashAggKVTypes(
    ctx,
    groupKeyTypesTerm,
    aggBufferTypesTerm,
    groupKeyRowType,
    aggBufferRowType)

  val binaryRowTypeTerm = classOf[BinaryRowData].getName
  // gen code to aggregate and output using hash map
  val aggregateMapTerm = CodeGenUtils.newName(ctx, "aggregateMap")
  val lookupInfoTypeTerm = classOf[BytesMap.LookupInfo[_, _]].getCanonicalName
  val lookupInfo = ctx.addReusableLocalVariable(lookupInfoTypeTerm, "lookupInfo")
  HashAggCodeGenHelper.prepareHashAggMap(
    ctx,
    groupKeyTypesTerm,
    aggBufferTypesTerm,
    aggregateMapTerm)

  val outputTerm = CodeGenUtils.newName(ctx, "hashAggOutput")
  val (reuseGroupKeyTerm, reuseAggBufferTerm) =
    HashAggCodeGenHelper.prepareTermForAggMapIteration(
      ctx,
      outputTerm,
      outputType,
      if (grouping.isEmpty) classOf[GenericRowData] else classOf[JoinedRowData])

  val currentAggBufferTerm = ctx.addReusableLocalVariable(binaryRowTypeTerm, "currentAggBuffer")
  val (initedAggBuffer, aggregate, outputExpr) = HashAggCodeGenHelper.genHashAggCodes(
    isMerge,
    isFinal,
    ctx,
    builder,
    (grouping, auxGrouping),
    inputTerm,
    inputType,
    aggInfos,
    currentAggBufferTerm,
    aggBufferRowType,
    aggBufferTypes,
    outputTerm,
    outputType,
    reuseGroupKeyTerm,
    reuseAggBufferTerm
  )

  val outputResultFromMap = HashAggCodeGenHelper.genAggMapIterationAndOutput(
    ctx,
    isFinal,
    aggregateMapTerm,
    reuseGroupKeyTerm,
    reuseAggBufferTerm,
    outputExpr)

  // gen code to deal with hash map oom, if enable fallback we will use sort agg strategy
  val sorterTerm = CodeGenUtils.newName(ctx, "sorter")
  val retryAppend = HashAggCodeGenHelper.genRetryAppendToMap(
    aggregateMapTerm,
    currentKeyTerm,
    initedAggBuffer,
    lookupInfo,
    currentAggBufferTerm)

  val (dealWithAggHashMapOOM, fallbackToSortAggCode) = HashAggCodeGenHelper.genAggMapOOMHandling(
    isFinal,
    ctx,
    builder,
    (grouping, auxGrouping),
    aggInfos,
    functionIdentifiers,
    logTerm,
    aggregateMapTerm,
    (groupKeyTypesTerm, aggBufferTypesTerm),
    (groupKeyRowType, aggBufferRowType),
    aggBufferPrefix,
    aggBufferNames,
    aggBufferTypes,
    outputTerm,
    outputType,
    outputResultFromMap,
    sorterTerm,
    retryAppend,
    maxNumFileHandles,
    compressionEnabled,
    compressionBlockSize
  )

  HashAggCodeGenHelper.prepareMetrics(ctx, aggregateMapTerm, if (isFinal) sorterTerm else null)

  // Set once the adaptive sampling decides to stop aggregating locally. endInput of a local
  // aggregate checks it so the already-flushed map is not emitted a second time.
  val localAggSuppressedTerm = CodeGenUtils.newName(ctx, "localAggSuppressed")
  ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")

  val (
    distinctCountIncCode,
    totalCountIncCode,
    adaptiveSamplingCode,
    adaptiveLocalHashAggCode,
    flushResultSuppressEnableCode) = {
    // Adaptive local hash aggregation only applies to local (non-final) aggregates that the
    // planner marked as supported and that are enabled by configuration. The original note
    // here states that these conditions imply a distinct-style aggregation.
    if (
      !isFinal &&
      ctx.tableConfig.get(ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED) &&
      supportAdaptiveLocalHashAgg
    ) {
      val adaptiveDistinctCountTerm = CodeGenUtils.newName(ctx, "distinctCount")
      val adaptiveTotalCountTerm = CodeGenUtils.newName(ctx, "totalCount")
      ctx.addReusableMember(s"private transient long $adaptiveDistinctCountTerm = 0;")
      ctx.addReusableMember(s"private transient long $adaptiveTotalCountTerm = 0;")

      val samplingThreshold =
        ctx.tableConfig.get(
          ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD)
      // The option is a Long: emit it as a Java long literal. Without the `L` suffix a
      // threshold above Integer.MAX_VALUE would make the generated class fail to compile.
      val samplingThresholdLiteral = s"${samplingThreshold}L"
      val distinctValueRateThreshold =
        ctx.tableConfig.get(
          ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD)

      // Projects the aggregate input values of the current row. It is generated only here, so
      // a disabled adaptive aggregation registers no unused members or unboxing code.
      val valueProjectionCode =
        ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
          ctx,
          inputType,
          classOf[BinaryRowData],
          inputTerm = inputTerm,
          aggInfos,
          auxGrouping,
          outRecordTerm = currentValueTerm,
          outRecordWriterTerm = currentValueWriterTerm
        )

      // Emits the current row's projected key/value directly, without touching the map.
      val outputResultForAdaptiveLocalHashAgg = {
        val inputUnboxingCode = s"${ctx.reuseInputUnboxingCode(reuseAggBufferTerm)}"
        s"""
           | // set result and output
           | $reuseGroupKeyTerm = ($ROW_DATA)$currentKeyTerm;
           | $reuseAggBufferTerm = ($ROW_DATA)$currentValueTerm;
           | $inputUnboxingCode
           | ${outputExpr.code}
           | ${OperatorCodeGenerator.generateCollect(outputExpr.resultTerm)}
           |
           """.stripMargin
      }

      (
        s"$adaptiveDistinctCountTerm++;",
        s"$adaptiveTotalCountTerm++;",
        s"""
           |if ($adaptiveTotalCountTerm == $samplingThresholdLiteral) {
           |  $logTerm.info("Local hash aggregation checkpoint reached, sampling threshold = " +
           |    $samplingThresholdLiteral + ", distinct value count = " + $adaptiveDistinctCountTerm + ", total = " +
           |    $adaptiveTotalCountTerm + ", distinct value rate threshold = "
           |    + $distinctValueRateThreshold);
           |  if ($adaptiveDistinctCountTerm / (1.0 * $adaptiveTotalCountTerm) > $distinctValueRateThreshold) {
           |    $logTerm.info("Local hash aggregation is suppressed");
           |    $localAggSuppressedTerm = true;
           |  }
           |}
           |""".stripMargin,
        s"""
           |if ($localAggSuppressedTerm) {
           |  $valueProjectionCode
           |  $outputResultForAdaptiveLocalHashAgg
           |  return;
           |}
           |""".stripMargin,
        s"""
           |if ($localAggSuppressedTerm) {
           |  $outputResultFromMap
           |  return;
           |}
           |""".stripMargin)
    } else {
      ("", "", "", "", "")
    }
  }

  // With auxiliary grouping fields the initial buffer is rebuilt for every new key.
  val lazyInitAggBufferCode = if (auxGrouping.nonEmpty) {
    s"""
       |// lazy init agg buffer (with auxGrouping)
       |${initedAggBuffer.code}
       """.stripMargin
  } else {
    ""
  }

  val processCode =
    s"""
       | // input field access for group key projection and aggregate buffer update
       |${ctx.reuseInputUnboxingCode(inputTerm)}
       | // project key from input
       |$keyProjectionCode
       |
       |$adaptiveLocalHashAggCode
       |
       | // look up output buffer using current group key
       |$lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
       |$currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
       |
       |if (!$lookupInfo.isFound()) {
       |  $distinctCountIncCode
       |  $lazyInitAggBufferCode
       |  // append empty agg buffer into aggregate map for current group key
       |  try {
       |    $currentAggBufferTerm =
       |      $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
       |  } catch (java.io.EOFException exp) {
       |    $dealWithAggHashMapOOM
       |  }
       |}
       |
       |$totalCountIncCode
       |$adaptiveSamplingCode
       |
       | // aggregate buffer fields access
       |${ctx.reuseInputUnboxingCode(currentAggBufferTerm)}
       | // do aggregate and update agg buffer
       |${aggregate.code}
       | // flush result form map if suppress is enable.
       |$flushResultSuppressEnableCode
       |""".stripMargin.trim

  val endInputCode = if (isFinal) {
    val memPoolTypeTerm = classOf[BytesHashMapSpillMemorySegmentPool].getName
    s"""
       |if ($sorterTerm == null) {
       | // no spilling, output by iterating aggregate map.
       | $outputResultFromMap
       |} else {
       |  // spill last part of input' aggregation output buffer
       |  $sorterTerm.sortAndSpill(
       |    $aggregateMapTerm.getRecordAreaMemorySegments(),
       |    $aggregateMapTerm.getNumElements(),
       |    new $memPoolTypeTerm($aggregateMapTerm.getBucketAreaMemorySegments()));
       |   // only release floating memory in advance.
       |  $aggregateMapTerm.free(true);
       |  // fall back to sort based aggregation
       |  $fallbackToSortAggCode
       |}
       """.stripMargin
  } else {
    // A suppressed local aggregate already flushed its map when suppression kicked in.
    s"""
       |if (!$localAggSuppressedTerm) {
       | $outputResultFromMap
       |}
       |""".stripMargin
  }

  AggCodeGenHelper.generateOperator(
    ctx,
    className,
    classOf[TableStreamOperator[RowData]].getCanonicalName,
    processCode,
    endInputCode,
    inputType)
}