override def onMatch()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala [140:409]


  /**
   * Splits the matched aggregate into a two-phase PARTIAL / FINAL aggregation.
   *
   * The replacement plan built on `relBuilder` is, bottom-up:
   *   1. optionally, a projection that appends `MOD(HASH_CODE(field), buckets)` columns for the
   *      arguments of aggregate calls selected by `SplitAggregateRule.needAddHashFields`
   *      (arguments that are already group keys get no hash column);
   *   2. optionally, an expand node (via `ExpandUtil.buildExpandNode`) followed by a projection of
   *      BOOLEAN filter columns, when more than one new group set is introduced (STEP 2);
   *   3. a PARTIAL [[FlinkLogicalAggregate]] grouped by the union of all group sets;
   *   4. a FINAL [[FlinkLogicalAggregate]] grouped by the original group keys;
   *   5. optionally, a projection that merges multi-column final results (AVG = sum / count,
   *      returning NULL when count is zero);
   *   6. a conversion to the original aggregate's row type.
   * The result is registered with `call.transformTo`.
   *
   * @param call rule call whose `rel(0)` is the original [[FlinkLogicalAggregate]] and whose
   *             `rel(1)` is its input
   */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val tableConfig = unwrapTableConfig(call)
    val originalAggregate: FlinkLogicalAggregate = call.rel(0)
    val aggCalls = originalAggregate.getAggCallList
    val input: FlinkRelNode = call.rel(1)
    // NOTE(review): `cluster` is not referenced anywhere in this method.
    val cluster = originalAggregate.getCluster
    val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder]
    relBuilder.push(input)
    val aggGroupSet = originalAggregate.getGroupSet.toArray

    // STEP 1: add hash fields if necessary
    // Collect the argument indexes of every aggregate call that needs hash fields, drop
    // duplicates and fields that are already group keys, and sort them so the appended hash
    // columns have a deterministic order.
    val hashFieldIndexes: Array[Int] = aggCalls
      .flatMap {
        aggCall =>
          if (SplitAggregateRule.needAddHashFields(aggCall)) {
            SplitAggregateRule.getArgIndexes(aggCall)
          } else {
            Array.empty[Int]
          }
      }
      .distinct
      .diff(aggGroupSet)
      .sorted
      .toArray

    // Maps an original input field index to the index of the hash column appended for it.
    val hashFieldsMap: util.Map[Int, Int] = new util.HashMap()
    // Modulus applied to the hash code; read from TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_BUCKET_NUM.
    val buckets =
      tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_BUCKET_NUM)

    if (hashFieldIndexes.nonEmpty) {
      // Keep every existing input field and append one hash column per selected field after them.
      val projects = new util.ArrayList[RexNode](relBuilder.fields)
      val hashFieldsOffset = projects.size()

      hashFieldIndexes.zipWithIndex.foreach {
        case (hashFieldIdx, index) =>
          val hashField = relBuilder.field(hashFieldIdx)
          // hash(f) % buckets
          val node: RexNode = relBuilder.call(
            SqlStdOperatorTable.MOD,
            relBuilder.call(FlinkSqlOperatorTable.HASH_CODE, hashField),
            relBuilder.literal(buckets))
          projects.add(node)
          hashFieldsMap.put(hashFieldIdx, hashFieldsOffset + index)
      }
      relBuilder.project(projects)
    }

    // STEP 2: construct partial aggregates
    // Assign each aggregate call the group set it will be computed under.
    // - For calls that need hash fields: the original group keys plus the call's arguments,
    //   where each argument is replaced by its hash column if STEP 1 added one.
    // - For all other calls: the original group keys.
    // The TreeSet collects the distinct group sets in ImmutableBitSet.ORDERING order.
    val groupSetTreeSet = new util.TreeSet[ImmutableBitSet](ImmutableBitSet.ORDERING)
    val aggInfoToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
    var newGroupSetsNum = 0
    aggCalls.foreach {
      aggCall =>
        val groupSet = if (SplitAggregateRule.needAddHashFields(aggCall)) {
          val newIndexes = SplitAggregateRule
            .getArgIndexes(aggCall)
            .map(argIndex => hashFieldsMap.getOrElse(argIndex, argIndex).asInstanceOf[Integer])
            .toSeq
          val newGroupSet =
            ImmutableBitSet.of(newIndexes).union(ImmutableBitSet.of(aggGroupSet: _*))
          // Only increment the group set count when a call that needs hash fields produces a
          // group set that has not been collected yet.
          // e.g SQL1: SELECT COUNT(DISTINCT a), MAX(a) FROM T group by b
          // newGroupSetsNum is 1 because both aggregate functions add the same hash field
          // e.g SQL2: SELECT COUNT(DISTINCT a), COUNT(b) FROM T group by c
          // newGroupSetsNum is 1 because only COUNT(DISTINCT a) adds a new hash field
          // e.g SQL3: SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM T group by b
          // newGroupSetsNum is 2 because the two calls produce different new group sets:
          // {b, hash(a)} and {b}. b is already a group key, so it gets no hash field.
          if (!groupSetTreeSet.contains(newGroupSet)) {
            newGroupSetsNum += 1
          }
          newGroupSet
        } else {
          ImmutableBitSet.of(aggGroupSet: _*)
        }
        groupSetTreeSet.add(groupSet)
        aggInfoToGroupSetMap.put(aggCall, groupSet)
    }
    val groupSets = ImmutableList.copyOf(asJavaIterable(groupSetTreeSet))
    // Union of all group sets; the partial aggregate groups by this.
    val fullGroupSet = ImmutableBitSet.union(groupSets)

    // STEP 2.1: expand input fields
    // Build the partial aggregate calls. One original call may yield several, as returned by
    // getPartialAggFunction. Record which group set each partial call belongs to.
    // The distinct flag, approximate flag, argument list and filter are copied from the
    // original call.
    val partialAggCalls = new util.ArrayList[AggregateCall]
    val partialAggCallToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
    aggCalls.foreach {
      aggCall =>
        val newAggCalls = SplitAggregateRule.getPartialAggFunction(aggCall).map {
          aggFunc =>
            AggregateCall.create(
              aggFunc,
              aggCall.isDistinct,
              aggCall.isApproximate,
              false,
              aggCall.getArgList,
              aggCall.filterArg,
              null,
              RelCollations.EMPTY,
              fullGroupSet.cardinality,
              relBuilder.peek(),
              null,
              null
            )
        }
        partialAggCalls.addAll(newAggCalls)
        newAggCalls.foreach {
          newAggCall =>
            partialAggCallToGroupSetMap.put(newAggCall, aggInfoToGroupSetMap.get(aggCall))
        }
    }

    // An expand node is built only when more than one new group set was introduced in STEP 2.
    // duplicateFieldMap comes from ExpandUtil.buildExpandNode and is used in STEP 2.2 to remap
    // the partial calls' argument indexes. It is empty when no expand is built.
    val needExpand = newGroupSetsNum > 1
    val duplicateFieldMap = if (needExpand) {
      val (duplicateFieldMap, _) =
        ExpandUtil.buildExpandNode(relBuilder, partialAggCalls, fullGroupSet, groupSets)
      duplicateFieldMap
    } else {
      Map.empty[Integer, Integer]
    }

    // STEP 2.2: add filter columns for partial aggregates
    // Keyed by (group set, original filterArg). The value is the index of the BOOLEAN filter
    // column created for that pair, so partial calls sharing both reuse the same column.
    val filters = new util.LinkedHashMap[(ImmutableBitSet, Integer), Integer]
    val newPartialAggCalls = new util.ArrayList[AggregateCall]
    if (needExpand) {
      // The last field of the expand output is taken as the expand id. Drop it from the
      // projection and instead append one BOOLEAN column per (group set, original filter) pair:
      // `expandId = <id of that group set>`, ANDed with the original filter column when the
      // call had one (oldFilterArg >= 0).
      val nodes = new util.ArrayList[RexNode](relBuilder.fields)
      val expandIdNode = nodes.remove(nodes.size - 1)
      val filterColumnsOffset: Int = nodes.size
      var x: Int = 0
      partialAggCalls.foreach {
        aggCall =>
          val groupSet = partialAggCallToGroupSetMap.get(aggCall)
          val oldFilterArg = aggCall.filterArg
          val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList

          if (!filters.contains(groupSet, oldFilterArg)) {
            val expandId = ExpandUtil.genExpandId(fullGroupSet, groupSet)
            if (oldFilterArg >= 0) {
              nodes.add(
                relBuilder.alias(
                  relBuilder.and(
                    relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
                    relBuilder.field(oldFilterArg)),
                  "$g_" + expandId))
            } else {
              nodes.add(
                relBuilder.alias(
                  relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
                  "$g_" + expandId))
            }
            val newFilterArg = filterColumnsOffset + x
            filters.put((groupSet, oldFilterArg), newFilterArg)
            x += 1
          }

          // Rebind the partial call to the remapped arguments and its filter column.
          // NOTE(review): newFilterArg indexes into the project built below, not into
          // relBuilder.peek() passed here. Confirm adaptTo does not validate filterArg against
          // this input.
          val newFilterArg = filters((groupSet, oldFilterArg))
          val newAggCall = aggCall.adaptTo(
            relBuilder.peek(),
            newArgList,
            newFilterArg,
            fullGroupSet.cardinality,
            fullGroupSet.cardinality)
          newPartialAggCalls.add(newAggCall)
      }
      relBuilder.project(nodes)
    } else {
      newPartialAggCalls.addAll(partialAggCalls)
    }

    // STEP 2.3: construct partial aggregates
    // The partial aggregate groups by fullGroupSet and is tagged PARTIAL.
    // Create aggregate node directly to avoid ClassCastException,
    // Please see FLINK-21923 for more details.
    // TODO reuse aggregate function, see FLINK-22412
    val partialAggregate = FlinkLogicalAggregate.create(
      relBuilder.build(),
      fullGroupSet,
      ImmutableList.of[ImmutableBitSet](fullGroupSet),
      newPartialAggCalls,
      originalAggregate.getHints)
    partialAggregate.setPartialFinalType(PartialFinalType.PARTIAL)
    relBuilder.push(partialAggregate)

    // STEP 3: construct final aggregates
    // Final calls read the partial output starting after its fullGroupSet.cardinality group
    // columns. `x` walks those columns sequentially, so each final call consumes exactly one.
    // NOTE(review): this assumes getFinalAggFunction yields, for each original call, as many
    // functions (in the same order) as getPartialAggFunction; that cannot be confirmed here.
    val finalAggInputOffset = fullGroupSet.cardinality
    var x: Int = 0
    val finalAggCalls = new util.ArrayList[AggregateCall]
    // Set when some original call maps to more than one final call, in which case STEP 4
    // must merge those columns back into one.
    var needMergeFinalAggOutput: Boolean = false
    aggCalls.foreach {
      aggCall =>
        val newAggCalls = SplitAggregateRule.getFinalAggFunction(aggCall).map {
          aggFunction =>
            val newArgList = ImmutableIntList.of(finalAggInputOffset + x)
            x += 1

            AggregateCall.create(
              aggFunction,
              false,
              aggCall.isApproximate,
              false,
              newArgList,
              -1,
              null,
              RelCollations.EMPTY,
              originalAggregate.getGroupCount,
              relBuilder.peek(),
              null,
              null)
        }

        finalAggCalls.addAll(newAggCalls)
        if (newAggCalls.size > 1) {
          needMergeFinalAggOutput = true
        }
    }
    // The final aggregate groups by the original group keys. SplitAggregateRule.remap
    // presumably translates them into column positions within fullGroupSet (the partial
    // aggregate's group columns); confirm against its definition.
    // Create aggregate node directly to avoid ClassCastException,
    // Please see FLINK-21923 for more details.
    // TODO reuse aggregate function, see FLINK-22412
    val finalAggregate = FlinkLogicalAggregate.create(
      relBuilder.build(),
      SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
      SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet)),
      finalAggCalls,
      originalAggregate.getHints
    )
    finalAggregate.setPartialFinalType(PartialFinalType.FINAL)
    relBuilder.push(finalAggregate)

    // STEP 4: convert final aggregation output to the original aggregation output.
    // For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of
    // the final aggregation is (sum, count). We convert it to (sum / count)
    // for the final output.
    val aggGroupCount = finalAggregate.getGroupCount
    if (needMergeFinalAggOutput) {
      // Pass the group columns through unchanged.
      val nodes = new util.ArrayList[RexNode]
      (0 until aggGroupCount).foreach {
        index => nodes.add(RexInputRef.of(index, finalAggregate.getRowType))
      }

      // avgAggCount counts the AVG calls seen so far. Each AVG occupies two adjacent final
      // columns (sum, count), shifting every later call's column by one.
      // NOTE(review): assumes AVG is the only function whose final form has two calls.
      var avgAggCount: Int = 0
      aggCalls.zipWithIndex.foreach {
        case (aggCall, index) =>
          val newNode = if (aggCall.getAggregation.getKind == SqlKind.AVG) {
            val sumInputRef =
              RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
            val countInputRef =
              RexInputRef.of(aggGroupCount + index + avgAggCount + 1, finalAggregate.getRowType)
            avgAggCount += 1
            // Make a guarantee that the final aggregation returns NULL if underlying count is ZERO.
            // We use SUM0 for underlying sum, which may run into ZERO / ZERO,
            // and division by zero exception occurs.
            // @see Glossary#SQL2011 SQL:2011 Part 2 Section 6.27
            val equals = relBuilder.call(
              FlinkSqlOperatorTable.EQUALS,
              countInputRef,
              relBuilder.getRexBuilder.makeBigintLiteral(JBigDecimal.valueOf(0)))
            val ifTrue = relBuilder.getRexBuilder.makeNullLiteral(aggCall.`type`)
            val ifFalse = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
            relBuilder.call(FlinkSqlOperatorTable.IF, equals, ifTrue, ifFalse)
          } else {
            RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
          }
          nodes.add(newNode)
      }
      relBuilder.project(nodes)
    }

    // Convert the top of the builder to the original aggregate's row type so the replacement
    // matches it. The `false` argument is presumably Calcite's `rename` flag — confirm against
    // RelBuilder.convert.
    relBuilder.convert(originalAggregate.getRowType, false)

    val newRel = relBuilder.build()
    call.transformTo(newRel)
  }