in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala [140:409]
/**
 * Rewrites the matched aggregate into a PARTIAL aggregate followed by a FINAL aggregate and
 * hands the result to the planner via `call.transformTo`.
 *
 * The rewrite proceeds in four steps:
 *   1. For every argument of an aggregate call that needs hash fields (and that is not already
 *      a group key), project an extra column `HASH_CODE(arg) % buckets`.
 *   1. Build the partial aggregate, grouped by the original group keys plus the hash columns.
 *      When more than one distinct group set is required, an expand node and per-group-set
 *      boolean filter columns are inserted first.
 *   1. Build the final aggregate on top of the partial one, grouped by the original group keys.
 *   1. If any original call was split into more than one final call, add a projection that
 *      merges the pieces back into a single output column per original call.
 *
 * @param call
 *   the rule call; operand 0 is read as the [[FlinkLogicalAggregate]] to split and operand 1 is
 *   pushed onto the builder as its input
 */
override def onMatch(call: RelOptRuleCall): Unit = {
  val tableConfig = unwrapTableConfig(call)
  val originalAggregate: FlinkLogicalAggregate = call.rel(0)
  val aggCalls = originalAggregate.getAggCallList
  val input: FlinkRelNode = call.rel(1)
  val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder]
  relBuilder.push(input)
  val aggGroupSet = originalAggregate.getGroupSet.toArray

  // STEP 1: add hash fields if necessary
  // Collect the argument indexes of all aggregate calls that need a hash field, de-duplicated,
  // excluding indexes that are already group keys, in ascending order.
  val hashFieldIndexes: Array[Int] = aggCalls
    .flatMap {
      aggCall =>
        if (SplitAggregateRule.needAddHashFields(aggCall)) {
          SplitAggregateRule.getArgIndexes(aggCall)
        } else {
          Array.empty[Int]
        }
    }
    .distinct
    .diff(aggGroupSet)
    .sorted
    .toArray

  // Maps an original argument index to the index of its projected hash column.
  val hashFieldsMap: util.Map[Int, Int] = new util.HashMap()
  val buckets =
    tableConfig.get(OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_BUCKET_NUM)

  if (hashFieldIndexes.nonEmpty) {
    // Keep all existing fields and append one hash column per collected index.
    val projects = new util.ArrayList[RexNode](relBuilder.fields)
    val hashFieldsOffset = projects.size()

    hashFieldIndexes.zipWithIndex.foreach {
      case (hashFieldIdx, index) =>
        val hashField = relBuilder.field(hashFieldIdx)
        // hash(f) % buckets
        val node: RexNode = relBuilder.call(
          SqlStdOperatorTable.MOD,
          relBuilder.call(FlinkSqlOperatorTable.HASH_CODE, hashField),
          relBuilder.literal(buckets))
        projects.add(node)
        hashFieldsMap.put(hashFieldIdx, hashFieldsOffset + index)
    }
    relBuilder.project(projects)
  }

  // STEP 2: construct partial aggregates
  // Distinct group sets required by the aggregate calls, kept in ImmutableBitSet.ORDERING order.
  val groupSetTreeSet = new util.TreeSet[ImmutableBitSet](ImmutableBitSet.ORDERING)
  val aggInfoToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
  var newGroupSetsNum = 0
  aggCalls.foreach {
    aggCall =>
      val groupSet = if (SplitAggregateRule.needAddHashFields(aggCall)) {
        // Replace each argument index by its hash column index where one was projected;
        // arguments without an entry in hashFieldsMap keep their own index.
        val newIndexes = SplitAggregateRule
          .getArgIndexes(aggCall)
          .map(argIndex => hashFieldsMap.getOrElse(argIndex, argIndex).asInstanceOf[Integer])
          .toSeq
        val newGroupSet =
          ImmutableBitSet.of(newIndexes).union(ImmutableBitSet.of(aggGroupSet: _*))
        // Only increment the group set count if the aggregate call introduces a group set
        // that has not been seen before.
        // e.g. SQL1: SELECT COUNT(DISTINCT a), MAX(a) FROM T group by b
        // newGroupSetsNum is 1 because both aggregate functions add the same hash field
        // e.g. SQL2: SELECT COUNT(DISTINCT a), COUNT(b) FROM T group by c
        // newGroupSetsNum is 1 because only COUNT(DISTINCT a) adds a new hash field
        // e.g. SQL3: SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM T group by b
        // newGroupSetsNum is 2 because COUNT(DISTINCT a), COUNT(DISTINCT b) both add hash field
        if (!groupSetTreeSet.contains(newGroupSet)) {
          newGroupSetsNum += 1
        }
        newGroupSet
      } else {
        ImmutableBitSet.of(aggGroupSet: _*)
      }
      groupSetTreeSet.add(groupSet)
      aggInfoToGroupSetMap.put(aggCall, groupSet)
  }
  val groupSets = ImmutableList.copyOf(asJavaIterable(groupSetTreeSet))
  val fullGroupSet = ImmutableBitSet.union(groupSets)

  // STEP 2.1: expand input fields
  // Create the partial aggregate calls; each one remembers the group set of the original call
  // it was derived from.
  val partialAggCalls = new util.ArrayList[AggregateCall]
  val partialAggCallToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
  aggCalls.foreach {
    aggCall =>
      val newAggCalls = SplitAggregateRule.getPartialAggFunction(aggCall).map {
        aggFunc =>
          AggregateCall.create(
            aggFunc,
            aggCall.isDistinct,
            aggCall.isApproximate,
            false,
            aggCall.getArgList,
            aggCall.filterArg,
            null,
            RelCollations.EMPTY,
            fullGroupSet.cardinality,
            relBuilder.peek(),
            null,
            null
          )
      }
      partialAggCalls.addAll(newAggCalls)
      newAggCalls.foreach {
        newAggCall =>
          partialAggCallToGroupSetMap.put(newAggCall, aggInfoToGroupSetMap.get(aggCall))
      }
  }

  // An expand node is only built when more than one new group set was counted above.
  val needExpand = newGroupSetsNum > 1
  val duplicateFieldMap = if (needExpand) {
    val (duplicateFieldMap, _) =
      ExpandUtil.buildExpandNode(relBuilder, partialAggCalls, fullGroupSet, groupSets)
    duplicateFieldMap
  } else {
    Map.empty[Integer, Integer]
  }

  // STEP 2.2: add filter columns for partial aggregates
  // Maps (groupSet, original filter arg) to the index of the generated filter column.
  val filters = new util.LinkedHashMap[(ImmutableBitSet, Integer), Integer]
  val newPartialAggCalls = new util.ArrayList[AggregateCall]
  if (needExpand) {
    // The last field on top of the builder is taken as the integer expand id (presumably the
    // column appended by ExpandUtil.buildExpandNode). It is dropped from the projection and
    // replaced by BOOLEAN columns of the form `expandId = <literal>`, one per distinct
    // (groupSet, filterArg) pair.
    val nodes = new util.ArrayList[RexNode](relBuilder.fields)
    val expandIdNode = nodes.remove(nodes.size - 1)
    val filterColumnsOffset: Int = nodes.size
    // Number of filter columns appended to `nodes` so far.
    var filterColumnCount: Int = 0
    partialAggCalls.foreach {
      aggCall =>
        val groupSet = partialAggCallToGroupSetMap.get(aggCall)
        val oldFilterArg = aggCall.filterArg
        // Redirect arguments to their duplicated fields where the expand step created one.
        val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList

        if (!filters.contains((groupSet, oldFilterArg))) {
          val expandId = ExpandUtil.genExpandId(fullGroupSet, groupSet)
          if (oldFilterArg >= 0) {
            // The call already had a filter: AND it with the expand id check.
            nodes.add(
              relBuilder.alias(
                relBuilder.and(
                  relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
                  relBuilder.field(oldFilterArg)),
                "$g_" + expandId))
          } else {
            nodes.add(
              relBuilder.alias(
                relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
                "$g_" + expandId))
          }
          val newFilterArg = filterColumnsOffset + filterColumnCount
          filters.put((groupSet, oldFilterArg), newFilterArg)
          filterColumnCount += 1
        }

        val newFilterArg = filters((groupSet, oldFilterArg))
        val newAggCall = aggCall.adaptTo(
          relBuilder.peek(),
          newArgList,
          newFilterArg,
          fullGroupSet.cardinality,
          fullGroupSet.cardinality)
        newPartialAggCalls.add(newAggCall)
    }
    relBuilder.project(nodes)
  } else {
    newPartialAggCalls.addAll(partialAggCalls)
  }

  // STEP 2.3: construct partial aggregates
  // Create aggregate node directly to avoid ClassCastException,
  // Please see FLINK-21923 for more details.
  // TODO reuse aggregate function, see FLINK-22412
  val partialAggregate = FlinkLogicalAggregate.create(
    relBuilder.build(),
    fullGroupSet,
    ImmutableList.of[ImmutableBitSet](fullGroupSet),
    newPartialAggCalls,
    originalAggregate.getHints)
  partialAggregate.setPartialFinalType(PartialFinalType.PARTIAL)
  relBuilder.push(partialAggregate)

  // STEP 3: construct final aggregates
  // The partial aggregate outputs its group keys first, so its aggregate results start at
  // index fullGroupSet.cardinality.
  val finalAggInputOffset = fullGroupSet.cardinality
  // Number of final aggregate calls created so far; each one reads the next partial result.
  var finalAggCallCount: Int = 0
  val finalAggCalls = new util.ArrayList[AggregateCall]
  var needMergeFinalAggOutput: Boolean = false
  aggCalls.foreach {
    aggCall =>
      val newAggCalls = SplitAggregateRule.getFinalAggFunction(aggCall).map {
        aggFunction =>
          val newArgList = ImmutableIntList.of(finalAggInputOffset + finalAggCallCount)
          finalAggCallCount += 1

          AggregateCall.create(
            aggFunction,
            false,
            aggCall.isApproximate,
            false,
            newArgList,
            -1,
            null,
            RelCollations.EMPTY,
            originalAggregate.getGroupCount,
            relBuilder.peek(),
            null,
            null)
      }

      finalAggCalls.addAll(newAggCalls)
      // An original call that maps to several final calls must be merged back in STEP 4.
      if (newAggCalls.size > 1) {
        needMergeFinalAggOutput = true
      }
  }

  // Create aggregate node directly to avoid ClassCastException,
  // Please see FLINK-21923 for more details.
  // TODO reuse aggregate function, see FLINK-22412
  val finalAggregate = FlinkLogicalAggregate.create(
    relBuilder.build(),
    SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
    SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet)),
    finalAggCalls,
    originalAggregate.getHints
  )
  finalAggregate.setPartialFinalType(PartialFinalType.FINAL)
  relBuilder.push(finalAggregate)

  // STEP 4: convert final aggregation output to the original aggregation output.
  // For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of
  // the final aggregation is (sum, count). We should convert it to (sum / count)
  // for the final output.
  val aggGroupCount = finalAggregate.getGroupCount
  if (needMergeFinalAggOutput) {
    val nodes = new util.ArrayList[RexNode]
    // Group keys are passed through unchanged.
    (0 until aggGroupCount).foreach {
      index => nodes.add(RexInputRef.of(index, finalAggregate.getRowType))
    }

    // Number of AVG calls seen so far. Each one occupies two columns in the final aggregate
    // output, so it shifts the position of every later call by one.
    var avgAggCount: Int = 0
    aggCalls.zipWithIndex.foreach {
      case (aggCall, index) =>
        val newNode = if (aggCall.getAggregation.getKind == SqlKind.AVG) {
          val sumInputRef =
            RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
          val countInputRef =
            RexInputRef.of(aggGroupCount + index + avgAggCount + 1, finalAggregate.getRowType)
          avgAggCount += 1
          // Make a guarantee that the final aggregation returns NULL if underlying count is ZERO.
          // We use SUM0 for underlying sum, which may run into ZERO / ZERO,
          // and division by zero exception occurs.
          // @see Glossary#SQL2011 SQL:2011 Part 2 Section 6.27
          val equals = relBuilder.call(
            FlinkSqlOperatorTable.EQUALS,
            countInputRef,
            relBuilder.getRexBuilder.makeBigintLiteral(JBigDecimal.valueOf(0)))
          val ifTrue = relBuilder.getRexBuilder.makeNullLiteral(aggCall.`type`)
          val ifFalse = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
          relBuilder.call(FlinkSqlOperatorTable.IF, equals, ifTrue, ifFalse)
        } else {
          RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
        }
        nodes.add(newNode)
    }
    relBuilder.project(nodes)
  }

  // Align the result's row type with the original aggregate's row type before replacing it.
  relBuilder.convert(originalAggregate.getRowType, false)

  val newRel = relBuilder.build()
  call.transformTo(newRel)
}