def genWithKeys()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala [45:352]


  /**
   * Generates the operator for a hash-based aggregation that has group keys.
   *
   * The generated class is named "HashAggregateWithKeys" when `isFinal` is true and
   * "LocalHashAggregateWithKeys" otherwise. Per input record, the generated process code:
   *   1. unboxes the input fields and projects the group key from the input;
   *   1. (adaptive local agg only) if local aggregation has been suppressed, projects the value
   *      row, emits key + value directly and returns;
   *   1. looks the key up in the aggregate map and, when it is not found, appends an initial agg
   *      buffer (an `EOFException` raised by the append is routed to the generated OOM handling
   *      code);
   *   1. (adaptive local agg only) updates the sampling counters and, exactly when the total
   *      count reaches the sampling threshold, decides whether to suppress local aggregation;
   *   1. applies the aggregate update to the agg buffer;
   *   1. (adaptive local agg only) if suppression is on, flushes the map content and returns.
   *
   * The generated end-of-input code outputs the map content; for the final phase it instead
   * spills the remaining map content and runs the sort-agg fallback code if a sorter was created.
   *
   * NOTE(review): most statements below mutate `ctx` (new names, reusable members, reusable
   * unboxing code), so their relative order presumably matters for the generated code — confirm
   * before reordering anything.
   *
   * @param ctx
   *   code generator context that collects reusable members/variables and provides the table
   *   config
   * @param builder
   *   forwarded to `HashAggCodeGenHelper.genHashAggCodes` and `genAggMapOOMHandling`
   * @param aggInfoList
   *   source of the aggregate infos (`aggInfoList.aggInfos`)
   * @param inputType
   *   row type of the operator input
   * @param outputType
   *   row type of the operator output
   * @param grouping
   *   input field indices that form the group key (used to project `inputType`)
   * @param auxGrouping
   *   auxiliary grouping indices; they contribute to the agg buffer names/types, and when
   *   non-empty the agg buffer init code is emitted right before the append
   * @param isMerge
   *   forwarded to `HashAggCodeGenHelper.genHashAggCodes`
   * @param isFinal
   *   selects the generated class name, the end-of-input code, and whether the sorter term is
   *   handed to `prepareMetrics`
   * @param supportAdaptiveLocalHashAgg
   *   together with `!isFinal` and the `TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED` option,
   *   enables the adaptive local hash aggregation code paths
   * @param maxNumFileHandles
   *   forwarded to `genAggMapOOMHandling` (presumably configures the spill sorter — confirm
   *   against the helper)
   * @param compressionEnabled
   *   forwarded to `genAggMapOOMHandling` (presumably spill compression switch — confirm)
   * @param compressionBlockSize
   *   forwarded to `genAggMapOOMHandling` (presumably spill compression block size — confirm)
   * @return
   *   the operator produced by `AggCodeGenHelper.generateOperator` from the assembled process
   *   and end-input code
   */
  def genWithKeys(
      ctx: CodeGeneratorContext,
      builder: RelBuilder,
      aggInfoList: AggregateInfoList,
      inputType: RowType,
      outputType: RowType,
      grouping: Array[Int],
      auxGrouping: Array[Int],
      isMerge: Boolean,
      isFinal: Boolean,
      supportAdaptiveLocalHashAgg: Boolean,
      maxNumFileHandles: Int,
      compressionEnabled: Boolean,
      compressionBlockSize: Int): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {

    // Derive agg buffer layout (names/types, prefixed with "hash") and the key/buffer row types.
    val aggInfos = aggInfoList.aggInfos
    val functionIdentifiers = AggCodeGenHelper.getFunctionIdentifiers(aggInfos)
    val aggBufferPrefix = "hash"
    val aggBufferNames = AggCodeGenHelper.getAggBufferNames(aggBufferPrefix, auxGrouping, aggInfos)
    val aggBufferTypes = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
    val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
    val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)

    val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
    // The final phase and the local phase get different generated class names.
    val className = if (isFinal) "HashAggregateWithKeys" else "LocalHashAggregateWithKeys"

    // Register a reusable logger for the generated class.
    val logTerm = CodeGenUtils.newName(ctx, "LOG")
    ctx.addReusableLogger(logTerm, className)

    // Generate the code that projects the group key out of the input row.
    val currentKeyTerm = CodeGenUtils.newName(ctx, "currentKey")
    val currentKeyWriterTerm = CodeGenUtils.newName(ctx, "currentKeyWriter")
    // currentValueTerm and currentValueWriterTerm are used for the value
    // projection when supportAdaptiveLocalHashAgg is true.
    val currentValueTerm = CodeGenUtils.newName(ctx, "currentValue")
    val currentValueWriterTerm = CodeGenUtils.newName(ctx, "currentValueWriter")
    val keyProjectionCode = ProjectionCodeGenerator
      .generateProjectionExpression(
        ctx,
        inputType,
        groupKeyRowType,
        grouping,
        inputTerm = inputTerm,
        outRecordTerm = currentKeyTerm,
        outRecordWriterTerm = currentKeyWriterTerm)
      .code

    // Generate the code that creates the groupKey / aggBuffer type arrays.
    // They are used by BytesHashMap, and by BufferedKVExternalSorter if fallback is enabled.
    val groupKeyTypesTerm = CodeGenUtils.newName(ctx, "groupKeyTypes")
    val aggBufferTypesTerm = CodeGenUtils.newName(ctx, "aggBufferTypes")
    HashAggCodeGenHelper.prepareHashAggKVTypes(
      ctx,
      groupKeyTypesTerm,
      aggBufferTypesTerm,
      groupKeyRowType,
      aggBufferRowType)

    val binaryRowTypeTerm = classOf[BinaryRowData].getName
    // Generate the code that aggregates into, and outputs from, the hash map.
    val aggregateMapTerm = CodeGenUtils.newName(ctx, "aggregateMap")
    val lookupInfoTypeTerm = classOf[BytesMap.LookupInfo[_, _]].getCanonicalName
    val lookupInfo = ctx.addReusableLocalVariable(lookupInfoTypeTerm, "lookupInfo")
    HashAggCodeGenHelper.prepareHashAggMap(
      ctx,
      groupKeyTypesTerm,
      aggBufferTypesTerm,
      aggregateMapTerm)

    // Terms for the reused key / agg buffer rows read while producing output. The output row
    // class handed to the helper is GenericRowData for an empty grouping, JoinedRowData otherwise.
    val outputTerm = CodeGenUtils.newName(ctx, "hashAggOutput")
    val (reuseGroupKeyTerm, reuseAggBufferTerm) =
      HashAggCodeGenHelper.prepareTermForAggMapIteration(
        ctx,
        outputTerm,
        outputType,
        if (grouping.isEmpty) classOf[GenericRowData] else classOf[JoinedRowData])

    // Core generated pieces: the initial agg buffer expression, the per-record aggregate update
    // and the expression that builds the output row.
    val currentAggBufferTerm = ctx.addReusableLocalVariable(binaryRowTypeTerm, "currentAggBuffer")
    val (initedAggBuffer, aggregate, outputExpr) = HashAggCodeGenHelper.genHashAggCodes(
      isMerge,
      isFinal,
      ctx,
      builder,
      (grouping, auxGrouping),
      inputTerm,
      inputType,
      aggInfos,
      currentAggBufferTerm,
      aggBufferRowType,
      aggBufferTypes,
      outputTerm,
      outputType,
      reuseGroupKeyTerm,
      reuseAggBufferTerm
    )

    // Code that iterates the aggregate map and emits its content; reused by the OOM handling,
    // the adaptive flush and the end-of-input code below.
    val outputResultFromMap = HashAggCodeGenHelper.genAggMapIterationAndOutput(
      ctx,
      isFinal,
      aggregateMapTerm,
      reuseGroupKeyTerm,
      reuseAggBufferTerm,
      outputExpr)

    // Generate the code that deals with hash map OOM; if fallback is enabled the sort agg
    // strategy is used.
    val sorterTerm = CodeGenUtils.newName(ctx, "sorter")
    val retryAppend = HashAggCodeGenHelper.genRetryAppendToMap(
      aggregateMapTerm,
      currentKeyTerm,
      initedAggBuffer,
      lookupInfo,
      currentAggBufferTerm)

    // dealWithAggHashMapOOM is placed in the catch block of the append in the process code;
    // fallbackToSortAggCode is placed in the final-phase end-of-input code.
    val (dealWithAggHashMapOOM, fallbackToSortAggCode) = HashAggCodeGenHelper.genAggMapOOMHandling(
      isFinal,
      ctx,
      builder,
      (grouping, auxGrouping),
      aggInfos,
      functionIdentifiers,
      logTerm,
      aggregateMapTerm,
      (groupKeyTypesTerm, aggBufferTypesTerm),
      (groupKeyRowType, aggBufferRowType),
      aggBufferPrefix,
      aggBufferNames,
      aggBufferTypes,
      outputTerm,
      outputType,
      outputResultFromMap,
      sorterTerm,
      retryAppend,
      maxNumFileHandles,
      compressionEnabled,
      compressionBlockSize
    )

    // The sorter term is only handed to the metrics helper for the final phase; the local phase
    // passes null instead.
    HashAggCodeGenHelper.prepareMetrics(ctx, aggregateMapTerm, if (isFinal) sorterTerm else null)

    // Adaptive local hash aggregation.
    val outputResultForAdaptiveLocalHashAgg = {
      // Generate the code that emits the current record directly: the projected key and the
      // projected value are used as the reused key / agg buffer rows, so the aggregate map is
      // bypassed.
      val inputUnboxingCode = s"${ctx.reuseInputUnboxingCode(reuseAggBufferTerm)}"
      s"""
         |   // set result and output
         |   $reuseGroupKeyTerm =  ($ROW_DATA)$currentKeyTerm;
         |   $reuseAggBufferTerm = ($ROW_DATA)$currentValueTerm;
         |   $inputUnboxingCode
         |   ${outputExpr.code}
         |   ${OperatorCodeGenerator.generateCollect(outputExpr.resultTerm)}
         |
       """.stripMargin
    }
    // Flag member of the generated class; it is always registered (initially false) because the
    // non-final end-of-input code reads it even when the adaptive code paths are not generated.
    val localAggSuppressedTerm = CodeGenUtils.newName(ctx, "localAggSuppressed")
    ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;")
    // Value projection used once local aggregation is suppressed. Note that this condition does
    // not check TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED, whereas the snippets that actually
    // use the projection (below) do.
    val valueProjectionCode =
      if (!isFinal && supportAdaptiveLocalHashAgg) {
        ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
          ctx,
          inputType,
          classOf[BinaryRowData],
          inputTerm = inputTerm,
          aggInfos,
          auxGrouping,
          outRecordTerm = currentValueTerm,
          outRecordWriterTerm = currentValueWriterTerm
        )
      } else {
        ""
      }

    // Five snippets spliced into the process code; all of them are empty strings unless this is
    // a local (non-final) aggregation with the adaptive option enabled and supported.
    val (
      distinctCountIncCode,
      totalCountIncCode,
      adaptiveSamplingCode,
      adaptiveLocalHashAggCode,
      flushResultSuppressEnableCode) = {
      // from these conditions we know that it must be a distinct operation
      if (
        !isFinal &&
        ctx.tableConfig.get(ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_ENABLED) &&
        supportAdaptiveLocalHashAgg
      ) {
        // Counters kept as members of the generated class: number of keys not found in the map
        // and number of records processed.
        val adaptiveDistinctCountTerm = CodeGenUtils.newName(ctx, "distinctCount")
        val adaptiveTotalCountTerm = CodeGenUtils.newName(ctx, "totalCount")
        ctx.addReusableMember(s"private transient long $adaptiveDistinctCountTerm = 0;")
        ctx.addReusableMember(s"private transient long $adaptiveTotalCountTerm = 0;")

        // Both thresholds are read at code generation time and inlined into the generated code.
        val samplingThreshold =
          ctx.tableConfig.get(
            ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_SAMPLING_THRESHOLD)
        val distinctValueRateThreshold =
          ctx.tableConfig.get(
            ExecutionConfigOptions.TABLE_EXEC_LOCAL_HASH_AGG_ADAPTIVE_DISTINCT_VALUE_RATE_THRESHOLD)

        // Tuple elements, in order:
        //  1. increment of the distinct counter (emitted in the "key not found" branch);
        //  2. increment of the total counter (emitted for every record that reaches it);
        //  3. sampling check: runs only when the total count equals the sampling threshold, logs
        //     the counters and sets the suppressed flag if distinct / total exceeds the rate
        //     threshold;
        //  4. fast path at the start of processing: when suppressed, project the value, emit the
        //     record directly and return;
        //  5. flush after the aggregate update: when suppressed, output the map content and
        //     return.
        (
          s"$adaptiveDistinctCountTerm++;",
          s"$adaptiveTotalCountTerm++;",
          s"""
             |if ($adaptiveTotalCountTerm == $samplingThreshold) {
             |  $logTerm.info("Local hash aggregation checkpoint reached, sampling threshold = " +
             |    $samplingThreshold + ", distinct value count = " + $adaptiveDistinctCountTerm + ", total = " +
             |    $adaptiveTotalCountTerm + ", distinct value rate threshold = " 
             |    + $distinctValueRateThreshold);
             |  if ($adaptiveDistinctCountTerm / (1.0 * $adaptiveTotalCountTerm) > $distinctValueRateThreshold) {
             |    $logTerm.info("Local hash aggregation is suppressed");
             |    $localAggSuppressedTerm = true;
             |  }
             |}
             |""".stripMargin,
          s"""
             |if ($localAggSuppressedTerm) {
             |  $valueProjectionCode
             |  $outputResultForAdaptiveLocalHashAgg
             |  return;
             |}
             |""".stripMargin,
          s"""
             |if ($localAggSuppressedTerm) {
             |  $outputResultFromMap
             |  return;
             |}
             |""".stripMargin)
      } else {
        ("", "", "", "", "")
      }
    }

    // With auxGrouping the agg buffer init code is emitted inside the "key not found" branch,
    // right before the append; without auxGrouping no init code is emitted there.
    val lazyInitAggBufferCode = if (auxGrouping.nonEmpty) {
      s"""
         |// lazy init agg buffer (with auxGrouping)
         |${initedAggBuffer.code}
       """.stripMargin
    } else {
      ""
    }

    // Per-record processing code of the generated operator.
    val processCode =
      s"""
         | // input field access for group key projection and aggregate buffer update
         |${ctx.reuseInputUnboxingCode(inputTerm)}
         | // project key from input
         |$keyProjectionCode
         |
         |$adaptiveLocalHashAggCode
         |
         | // look up output buffer using current group key
         |$lookupInfo = ($lookupInfoTypeTerm) $aggregateMapTerm.lookup($currentKeyTerm);
         |$currentAggBufferTerm = ($binaryRowTypeTerm) $lookupInfo.getValue();
         |
         |if (!$lookupInfo.isFound()) {
         |  $distinctCountIncCode
         |  $lazyInitAggBufferCode
         |  // append empty agg buffer into aggregate map for current group key
         |  try {
         |    $currentAggBufferTerm =
         |      $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
         |  } catch (java.io.EOFException exp) {
         |    $dealWithAggHashMapOOM
         |  }
         |}
         |
         |$totalCountIncCode
         |$adaptiveSamplingCode
         |
         | // aggregate buffer fields access
         |${ctx.reuseInputUnboxingCode(currentAggBufferTerm)}
         | // do aggregate and update agg buffer
         |${aggregate.code}
         | // flush result form map if suppress is enable. 
         |$flushResultSuppressEnableCode
         |""".stripMargin.trim

    // End-of-input code. Final phase: output from the map when no sorter exists, otherwise spill
    // the remaining map content, free the map and run the sort-agg fallback. Local phase: output
    // from the map unless local aggregation has been suppressed.
    val endInputCode = if (isFinal) {
      val memPoolTypeTerm = classOf[BytesHashMapSpillMemorySegmentPool].getName
      s"""
         |if ($sorterTerm == null) {
         | // no spilling, output by iterating aggregate map.
         | $outputResultFromMap
         |} else {
         |  // spill last part of input' aggregation output buffer
         |  $sorterTerm.sortAndSpill(
         |    $aggregateMapTerm.getRecordAreaMemorySegments(),
         |    $aggregateMapTerm.getNumElements(),
         |    new $memPoolTypeTerm($aggregateMapTerm.getBucketAreaMemorySegments()));
         |   // only release floating memory in advance.
         |   $aggregateMapTerm.free(true);
         |  // fall back to sort based aggregation
         |  $fallbackToSortAggCode
         |}
       """.stripMargin
    } else {
      s"""
         |if (!$localAggSuppressedTerm) {
         | $outputResultFromMap
         |}
         |""".stripMargin
    }

    // Assemble the operator class from the process and end-of-input code.
    AggCodeGenHelper.generateOperator(
      ctx,
      className,
      classOf[TableStreamOperator[RowData]].getCanonicalName,
      processCode,
      endInputCode,
      inputType)
  }