private def doProcessConsumeWithKeys()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/fusion/spec/HashAggFusionCodegenSpec.scala [141:337]


  /**
   * Generates the per-record consume code for a hash aggregate that has grouping keys.
   *
   * The returned code is wrapped in `do { ... } while(false);`. NOTE(review): presumably the
   * wrapper lets generated fragments (e.g. the adaptive local hash agg code) `break` out early;
   * confirm. Inside it, the code:
   *  - evaluates the input field accesses;
   *  - runs the adaptive local hash agg code;
   *  - projects the group key (`grouping.map(input(_))`) into a `BinaryRowData`;
   *  - looks the key up in the `BytesHashMap`.
   * When the key is absent, the code appends an initialized empty aggregate buffer. If `append`
   * throws `java.io.EOFException`, it runs the code from `genHashAggOOMHandling`. Finally it
   * updates the aggregate buffer in place and runs `flushResultSuppressEnableCode`.
   *
   * Side effects on the enclosing generator state:
   *  - registers on `opCodegenCtx`:
   *    - the `BytesHashMap` member;
   *    - its creation in the open statements, sized by
   *      `computeMemorySize(fusionContext.getManagedMemoryFraction)`;
   *    - its `free()` in the close statements;
   *    - the reuse key/buffer members and the `localAggSuppressed` flag member;
   *  - assigns `outputFromMap`, the code that iterates the map entries and outputs results;
   *  - assigns `localAggSuppressedTerm`;
   *  - binds the aggregate buffer row type as the second input of `getExprCodeGenerator`.
   *
   * NOTE(review): treat the statement order as significant. The generated-name calls and the
   * `startNewLocalVariableStatement` / `reuseLocalVariableCode` pairs likely depend on it, so
   * confirm before reordering.
   *
   * @param input generated expressions for the input row fields; the group key is built from
   *              the entries at the `grouping` indices
   * @return the generated process code for one input record
   */
  private def doProcessConsumeWithKeys(input: Seq[GeneratedExpression]): String = {
    // register the group key / agg buffer type arrays first: both the hash map created below
    // and the OOM handling reference them
    val Seq(groupKeyTypesTerm, aggBufferTypesTerm) =
      newNames(opCodegenCtx, "groupKeyTypes", "aggBufferTypes")
    // gen code to create the group key and agg buffer type arrays. They are used by BytesHashMap
    // and, if fallback is enabled, by BufferedKVExternalSorter.
    HashAggCodeGenHelper.prepareHashAggKVTypes(
      opCodegenCtx,
      groupKeyTypesTerm,
      aggBufferTypesTerm,
      groupKeyRowType,
      aggBufferRowType)

    // create the aggregate map in the open statements, sized from the managed memory fraction
    val memorySizeTerm = newName(opCodegenCtx, "memorySize")
    val mapTypeTerm = classOf[BytesHashMap].getName
    opCodegenCtx.addReusableMember(s"private transient $mapTypeTerm $aggregateMapTerm;")
    opCodegenCtx.addReusableOpenStatement(
      s"""
         |long $memorySizeTerm = computeMemorySize(${fusionContext.getManagedMemoryFraction});
         |$aggregateMapTerm = new $mapTypeTerm(
         |  getContainingTask(),
         |  getContainingTask().getEnvironment().getMemoryManager(),
         |  $memorySizeTerm,
         |  $groupKeyTypesTerm,
         |  $aggBufferTypesTerm);
       """.stripMargin)

    // close aggregate map and release memory segments
    opCodegenCtx.addReusableCloseStatement(s"$aggregateMapTerm.free();")

    val Seq(currentKeyTerm, currentKeyWriterTerm) =
      newNames(opCodegenCtx, "currentKey", "currentKeyWriter")
    val Seq(lookupInfo, currentAggBufferTerm) =
      newNames(opCodegenCtx, "lookupInfo", "currentAggBuffer")
    val lookupInfoTypeTerm = classOf[BytesMap.LookupInfo[_, _]].getCanonicalName
    val binaryRowTypeTerm = classOf[BinaryRowData].getName

    // evaluate input field access for group key projection and aggregate buffer update
    val inputAccessCode = evaluateVariables(input)
    // project the group key (input fields at the `grouping` indices) into a BinaryRowData
    // named currentKeyTerm, written through currentKeyWriterTerm
    val keyExprs = grouping.map(idx => input(idx))
    val keyProjectionCode = getExprCodeGenerator
      .generateResultExpression(
        keyExprs,
        groupKeyRowType,
        classOf[BinaryRowData],
        currentKeyTerm,
        outRowWriter = Option(currentKeyWriterTerm))
      .code

    // gen code to create the empty agg buffer. genReusableEmptyAggBuffer receives the input row
    // and auxGrouping, so (presumably) the buffer depends on the current record when auxGrouping
    // is non-empty. That is why the two cases are handled differently below.
    val initedAggBuffer = genReusableEmptyAggBuffer(
      opCodegenCtx,
      builder,
      inputRowTerm,
      inputType,
      auxGrouping,
      aggInfos,
      aggBufferRowType)
    val lazyInitAggBufferCode = if (auxGrouping.isEmpty) {
      // no auxGrouping: initialize the empty agg buffer once in the open statements and reuse it
      // for every new key, so nothing needs to be emitted per record
      opCodegenCtx.addReusableOpenStatement(initedAggBuffer.code)
      ""
    } else {
      // with auxGrouping: emit the init code into the process code; it is placed inside the
      // "key not found" branch, so it only runs when a new key is appended
      s"""
         |// lazy init agg buffer (with auxGrouping)
         |${initedAggBuffer.code}
       """.stripMargin
    }

    // generate code to update the agg buffer. The local-variable scope opened here is keyed by
    // currentAggBufferTerm and emitted via reuseLocalVariableCode(currentAggBufferTerm) in the
    // process code below (presumably collecting locals declared by genAggregate).
    opCodegenCtx.startNewLocalVariableStatement(currentAggBufferTerm)
    val aggregateExpr = genAggregate(
      isMerge,
      opCodegenCtx,
      builder,
      inputType,
      inputRowTerm,
      auxGrouping,
      aggInfos,
      argsMapping,
      aggBuffMapping,
      currentAggBufferTerm,
      aggBufferRowType
    )

    // members holding the key and value of the map entry currently being output
    val Seq(reuseAggMapKeyTerm, reuseAggBufferTerm) =
      newNames(opCodegenCtx, "reuseAggMapKey", "reuseAggBuffer")
    val rowData = classOf[RowData].getName
    opCodegenCtx.addReusableMember(s"private transient $rowData $reuseAggMapKeyTerm;")
    opCodegenCtx.addReusableMember(s"private transient $rowData $reuseAggBufferTerm;")
    // iterator over the aggregate map entries, used by the output code below
    val iteratorTerm = CodeGenUtils.newName(opCodegenCtx, "iterator")
    val iteratorType = classOf[KeyValueIterator[_, _]].getCanonicalName

    // local-variable scope for the output code; emitted via
    // reuseLocalVariableCode(reuseAggBufferTerm) at the start of outputFromMap
    opCodegenCtx.startNewLocalVariableStatement(reuseAggBufferTerm)
    val reuseKeyExprs = getReuseRowFieldExprs(opCodegenCtx, groupKeyRowType, reuseAggMapKeyTerm)
    // get value exprs from the agg buffer: first bind the buffer (reuseAggBufferTerm) as the
    // expression generator's second input
    getExprCodeGenerator
      .bindSecondInput(aggBufferRowType, reuseAggBufferTerm)
    val reuseValueExprs = genHashAggValueExpr(
      isMerge,
      isFinal,
      opCodegenCtx,
      getExprCodeGenerator,
      builder,
      auxGrouping,
      aggInfos,
      argsMapping,
      aggBuffMapping,
      inputType,
      reuseAggBufferTerm,
      aggBufferRowType
    )
    // if isFinal, emit the agg buffer field access (unboxing) code once per entry, so that the
    // field accesses are not evaluated multiple times
    val reuseAggBufferFieldCode = if (isFinal) {
      opCodegenCtx.reuseInputUnboxingCode(reuseAggBufferTerm)
    } else {
      ""
    }
    // code that drains the aggregate map and outputs one row (key fields ++ value fields) per
    // entry; stored in the outputFromMap member for use elsewhere
    outputFromMap =
      s"""
         |${opCodegenCtx.reuseLocalVariableCode(reuseAggBufferTerm)}
         |$iteratorType<$rowData, $rowData> $iteratorTerm =
         |  $aggregateMapTerm.getEntryIterator(false); // reuse key/value during iterating
         |while ($iteratorTerm.advanceNext()) {
         |   // set result and output
         |   $reuseAggMapKeyTerm = ($rowData)$iteratorTerm.getKey();
         |   $reuseAggBufferTerm = ($rowData)$iteratorTerm.getValue();
         |   // consume the row of agg produce
         |   $reuseAggBufferFieldCode
         |   ${outputResult(fusionContext.getOutputType, reuseKeyExprs ++ reuseValueExprs)}
         |}
       """.stripMargin

    // code that retries appending the current key's empty agg buffer; handed to the OOM
    // handling below
    val retryAppendCode = genRetryAppendToMap(
      aggregateMapTerm,
      currentKeyTerm,
      initedAggBuffer,
      lookupInfo,
      currentAggBufferTerm)

    // gen code to deal with hash map OOM (an EOFException from append in the process code);
    // if fallback is enabled, the sort agg strategy is used
    val dealWithAggHashMapOOM =
      genHashAggOOMHandling(groupKeyTypesTerm, aggBufferTypesTerm, retryAppendCode)

    // generate the adaptive local hash agg code; the suppression flag starts as false
    localAggSuppressedTerm = newName(opCodegenCtx, "localAggSuppressed")
    opCodegenCtx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")
    val (
      distinctCountIncCode,
      totalCountIncCode,
      adaptiveSamplingCode,
      adaptiveLocalHashAggCode,
      flushResultSuppressEnableCode) = genAdaptiveLocalHashAgg(keyExprs)

    // process code for one input record
    s"""
       |do {
       |   // input field access
       |  $inputAccessCode
       |
       |  $adaptiveLocalHashAggCode
       |
       |  // project key from input
       |  $keyProjectionCode
       |
       |   // lookup output buffer using current group key
       |  $lookupInfoTypeTerm $lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
       |  $binaryRowTypeTerm $currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
       |  if (!$lookupInfo.isFound()) {
       |    $distinctCountIncCode
       |    $lazyInitAggBufferCode
       |    // append empty agg buffer into aggregate map for current group key
       |    try {
       |      $currentAggBufferTerm =
       |        $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
       |    } catch (java.io.EOFException exp) {
       |      $dealWithAggHashMapOOM
       |    }
       |  }
       |
       |  $totalCountIncCode
       |  $adaptiveSamplingCode
       |
       |  // do aggregate and update agg buffer
       |  ${opCodegenCtx.reuseLocalVariableCode(currentAggBufferTerm)}
       |  // aggregate buffer fields access
       |  ${opCodegenCtx.reuseInputUnboxingCode(currentAggBufferTerm)}
       |  
       |  ${aggregateExpr.code}
       |  // flush result form map if suppress is enable.
       |  $flushResultSuppressEnableCode
       |} while(false);
       |""".stripMargin.trim
  }