in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashAggFusionCodegenSpec.scala [141:337]
/**
 * Generates the per-record processing code of the hash aggregate for the case where the input is
 * grouped by keys.
 *
 * Besides returning the processing snippet, this method has side effects: it registers reusable
 * members, open statements and a close statement on `opCodegenCtx`, and it assigns the fields
 * `outputFromMap` (code that iterates the aggregate map and outputs every entry) and
 * `localAggSuppressedTerm`.
 *
 * @param input
 *   expressions of the current input record; the entries at the `grouping` indices form the
 *   group key
 * @return
 *   generated code, wrapped in `do { ... } while(false);`, that projects the group key, looks it
 *   up in the aggregate map, appends an empty agg buffer when the key is not found, and then
 *   applies the aggregate update to the buffer
 */
private def doProcessConsumeWithKeys(input: Seq[GeneratedExpression]): String = {
// initialize the hashmap related code first
val Seq(groupKeyTypesTerm, aggBufferTypesTerm) =
newNames(opCodegenCtx, "groupKeyTypes", "aggBufferTypes")
// generate code that creates the group key / agg buffer type arrays; the two terms are passed
// to the BytesHashMap constructor below and to the OOM handling (sort fallback, if enabled)
HashAggCodeGenHelper.prepareHashAggKVTypes(
opCodegenCtx,
groupKeyTypesTerm,
aggBufferTypesTerm,
groupKeyRowType,
aggBufferRowType)
// create aggregate map: declared as a member and instantiated in the open statements with a
// memory size computed from the managed memory fraction of the fusion context
val memorySizeTerm = newName(opCodegenCtx, "memorySize")
val mapTypeTerm = classOf[BytesHashMap].getName
opCodegenCtx.addReusableMember(s"private transient $mapTypeTerm $aggregateMapTerm;")
opCodegenCtx.addReusableOpenStatement(
s"""
|long $memorySizeTerm = computeMemorySize(${fusionContext.getManagedMemoryFraction});
|$aggregateMapTerm = new $mapTypeTerm(
| getContainingTask(),
| getContainingTask().getEnvironment().getMemoryManager(),
| $memorySizeTerm,
| $groupKeyTypesTerm,
| $aggBufferTypesTerm);
""".stripMargin)
// register a close statement that frees the aggregate map
opCodegenCtx.addReusableCloseStatement(s"$aggregateMapTerm.free();")
// names used by the generated lookup / append code in the returned snippet
val Seq(currentKeyTerm, currentKeyWriterTerm) =
newNames(opCodegenCtx, "currentKey", "currentKeyWriter")
val Seq(lookupInfo, currentAggBufferTerm) =
newNames(opCodegenCtx, "lookupInfo", "currentAggBuffer")
val lookupInfoTypeTerm = classOf[BytesMap.LookupInfo[_, _]].getCanonicalName
val binaryRowTypeTerm = classOf[BinaryRowData].getName
// evaluate input field access for group key projection and aggregate buffer update
val inputAccessCode = evaluateVariables(input)
// project key row from input: pick the grouping fields and write them into a BinaryRowData
val keyExprs = grouping.map(idx => input(idx))
val keyProjectionCode = getExprCodeGenerator
.generateResultExpression(
keyExprs,
groupKeyRowType,
classOf[BinaryRowData],
currentKeyTerm,
outRowWriter = Option(currentKeyWriterTerm))
.code
// gen code to create empty agg buffer, here need to consider the auxGrouping is not empty case
val initedAggBuffer = genReusableEmptyAggBuffer(
opCodegenCtx,
builder,
inputRowTerm,
inputType,
auxGrouping,
aggInfos,
aggBufferRowType)
val lazyInitAggBufferCode = if (auxGrouping.isEmpty) {
// without auxGrouping the empty agg buffer is initialized once in the open statements and
// reused, so no inline init code is emitted in the processing snippet
opCodegenCtx.addReusableOpenStatement(initedAggBuffer.code)
""
} else {
// with auxGrouping the init code is emitted inline, in the "key not found" branch below
s"""
|// lazy init agg buffer (with auxGrouping)
|${initedAggBuffer.code}
""".stripMargin
}
// generate code to update agg buffer; the local-variable statement started here is emitted by
// the matching reuseLocalVariableCode(currentAggBufferTerm) call in the returned snippet
opCodegenCtx.startNewLocalVariableStatement(currentAggBufferTerm)
val aggregateExpr = genAggregate(
isMerge,
opCodegenCtx,
builder,
inputType,
inputRowTerm,
auxGrouping,
aggInfos,
argsMapping,
aggBuffMapping,
currentAggBufferTerm,
aggBufferRowType
)
// gen code to prepare agg output using agg buffer and key from the aggregate map
val Seq(reuseAggMapKeyTerm, reuseAggBufferTerm) =
newNames(opCodegenCtx, "reuseAggMapKey", "reuseAggBuffer")
val rowData = classOf[RowData].getName
opCodegenCtx.addReusableMember(s"private transient $rowData $reuseAggMapKeyTerm;")
opCodegenCtx.addReusableMember(s"private transient $rowData $reuseAggBufferTerm;")
// name and type of the iterator used to walk the aggregate map entries when emitting output
val iteratorTerm = CodeGenUtils.newName(opCodegenCtx, "iterator")
val iteratorType = classOf[KeyValueIterator[_, _]].getCanonicalName
// local-variable statement for the output path; emitted at the top of outputFromMap below
opCodegenCtx.startNewLocalVariableStatement(reuseAggBufferTerm)
val reuseKeyExprs = getReuseRowFieldExprs(opCodegenCtx, groupKeyRowType, reuseAggMapKeyTerm)
// get value expr from agg buffer
// NOTE(review): the result of bindSecondInput is discarded; presumably it binds the agg buffer
// as second input on the shared generator in place - confirm
getExprCodeGenerator
.bindSecondInput(aggBufferRowType, reuseAggBufferTerm)
val reuseValueExprs = genHashAggValueExpr(
isMerge,
isFinal,
opCodegenCtx,
getExprCodeGenerator,
builder,
auxGrouping,
aggInfos,
argsMapping,
aggBuffMapping,
inputType,
reuseAggBufferTerm,
aggBufferRowType
)
// for the final aggregate, emit the agg buffer field access code once inside the output loop
// so the fields are not evaluated repeatedly; otherwise nothing is emitted here
val reuseAggBufferFieldCode = if (isFinal) {
opCodegenCtx.reuseInputUnboxingCode(reuseAggBufferTerm)
} else {
""
}
// code that iterates over all entries of the aggregate map and outputs one result row per
// entry (key fields followed by value fields); stored in a field for use outside this method
outputFromMap =
s"""
|${opCodegenCtx.reuseLocalVariableCode(reuseAggBufferTerm)}
|$iteratorType<$rowData, $rowData> $iteratorTerm =
| $aggregateMapTerm.getEntryIterator(false); // reuse key/value during iterating
|while ($iteratorTerm.advanceNext()) {
| // set result and output
| $reuseAggMapKeyTerm = ($rowData)$iteratorTerm.getKey();
| $reuseAggBufferTerm = ($rowData)$iteratorTerm.getValue();
| // consume the row of agg produce
| $reuseAggBufferFieldCode
| ${outputResult(fusionContext.getOutputType, reuseKeyExprs ++ reuseValueExprs)}
|}
""".stripMargin
// code that retries appending the empty agg buffer for the current key; handed to the OOM
// handling below
val retryAppendCode = genRetryAppendToMap(
aggregateMapTerm,
currentKeyTerm,
initedAggBuffer,
lookupInfo,
currentAggBufferTerm)
// gen code to deal with hash map oom, if enable fallback we will use sort agg strategy;
// it is placed in the catch block for java.io.EOFException around the append call
val dealWithAggHashMapOOM =
genHashAggOOMHandling(groupKeyTypesTerm, aggBufferTypesTerm, retryAppendCode)
// generate the adaptive local hash agg code; the suppression flag is a boolean member that
// starts as false
localAggSuppressedTerm = newName(opCodegenCtx, "localAggSuppressed")
opCodegenCtx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")
val (
distinctCountIncCode,
totalCountIncCode,
adaptiveSamplingCode,
adaptiveLocalHashAggCode,
flushResultSuppressEnableCode) = genAdaptiveLocalHashAgg(keyExprs)
// process code: returned as a single do { ... } while(false); block
s"""
|do {
| // input field access
| $inputAccessCode
|
| $adaptiveLocalHashAggCode
|
| // project key from input
| $keyProjectionCode
|
| // lookup output buffer using current group key
| $lookupInfoTypeTerm $lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
| $binaryRowTypeTerm $currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
| if (!$lookupInfo.isFound()) {
| $distinctCountIncCode
| $lazyInitAggBufferCode
| // append empty agg buffer into aggregate map for current group key
| try {
| $currentAggBufferTerm =
| $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
| } catch (java.io.EOFException exp) {
| $dealWithAggHashMapOOM
| }
| }
|
| $totalCountIncCode
| $adaptiveSamplingCode
|
| // do aggregate and update agg buffer
| ${opCodegenCtx.reuseLocalVariableCode(currentAggBufferTerm)}
| // aggregate buffer fields access
| ${opCodegenCtx.reuseInputUnboxingCode(currentAggBufferTerm)}
|
| ${aggregateExpr.code}
| // flush result form map if suppress is enable.
| $flushResultSuppressEnableCode
|} while(false);
|""".stripMargin.trim
}