in core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java [428:561]
/**
 * Rewrites {@code aggregate}, which contains DISTINCT aggregate calls, into
 * two stacked aggregates that use grouping sets, and registers the result
 * via {@code call.transformTo}.
 *
 * <p>Plan shape produced:
 * <ol>
 *   <li>A lower aggregate over the original input. It is grouped by
 *   {@code fullGroupSet}, the union of all required group sets, with one
 *   grouping set per distinct required group set:
 *   <ul>
 *     <li>the aggregate's own group set, used by non-distinct calls;</li>
 *     <li>for each distinct call, its arguments, plus its filter column if
 *     any, plus the aggregate's group set.</li>
 *   </ul>
 *   It computes every non-distinct call, followed by a {@code GROUPING}
 *   call named {@code $g}.</li>
 *   <li>If {@code filters} is non-empty, a project that drops {@code $g}
 *   and appends one BOOLEAN column per (group set, filter arg) pair.</li>
 *   <li>An upper aggregate grouped by the original group set and group
 *   sets, remapped into the lower aggregate's field positions. In it:
 *   <ul>
 *     <li>each non-distinct call becomes {@code MIN} of the value the lower
 *     aggregate already computed;</li>
 *     <li>each distinct call becomes the same aggregate function without
 *     DISTINCT.</li>
 *   </ul>
 *   Every call is filtered by its BOOLEAN column.</li>
 *   <li>A conversion back to the original aggregate's row type.</li>
 * </ol>
 *
 * @param call      the rule call; supplies the {@link RelBuilder} and
 *                  receives the rewritten expression
 * @param aggregate the aggregate to rewrite
 */
private static void rewriteUsingGroupingSets(RelOptRuleCall call,
Aggregate aggregate) {
// All distinct group sets required by the aggregate calls, kept sorted
// (ImmutableBitSet.ORDERING) so the grouping-set list is deterministic.
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
// Maps each required group set to the filter args used with it.
// The filterArg is -1 for non-distinct agg calls and for distinct calls
// without a FILTER. A Set is used because two agg calls can share a
// group set but have different filterArgs.
final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new HashMap<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
ImmutableBitSet groupSet;
int filterArg;
if (!aggCall.isDistinct()) {
// Non-distinct calls are evaluated in the aggregate's own grouping set.
// Their own filter (if any) is applied by the lower aggregate via adaptTo
// below, so it is not recorded here.
filterArg = -1;
groupSet = aggregate.getGroupSet();
groupSetTreeSet.add(aggregate.getGroupSet());
} else {
// A distinct call needs a grouping set containing its arguments, its
// filter column (if any) and the aggregate's group keys. NOTE: the same
// expression is recomputed for the `filters` lookup in the last loop;
// the two must stay identical or that requireNonNull fails.
filterArg = aggCall.filterArg;
groupSet =
ImmutableBitSet.of(aggCall.getArgList())
.setIf(filterArg, filterArg >= 0)
.union(aggregate.getGroupSet());
groupSetTreeSet.add(groupSet);
}
Set<Integer> filterList = distinctFilterArgMap
.computeIfAbsent(groupSet, g -> new HashSet<>());
filterList.add(filterArg);
}
final ImmutableList<ImmutableBitSet> groupSets =
ImmutableList.copyOf(groupSetTreeSet);
// The lower aggregate groups by every column needed by any grouping set.
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
// Calls for the lower aggregate: every non-distinct call, adapted to the
// wider group key (group count goes from aggregate.getGroupCount() to
// fullGroupSet.cardinality()). The GROUPING call is appended later.
final List<AggregateCall> distinctAggCalls = new ArrayList<>();
for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
if (!aggCall.left.isDistinct()) {
AggregateCall newAggCall =
aggCall.left.adaptTo(aggregate.getInput(),
aggCall.left.getArgList(), aggCall.left.filterArg,
aggregate.getGroupCount(), fullGroupSet.cardinality());
distinctAggCalls.add(newAggCall.withName(aggCall.right));
}
}
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
final int groupCount = fullGroupSet.cardinality();
// Assign each (group set, filter arg) pair the field ordinal its BOOLEAN
// column will have after the project below. That project's output is:
// the group columns, then the non-distinct call results, then one column
// per `filters` entry in insertion order. The GROUPING column $g is
// removed, so numbering starts right after the non-distinct results.
// LinkedHashMap keeps insertion order so these ordinals match the project.
final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
int z = groupCount + distinctAggCalls.size();
for (ImmutableBitSet groupSet : groupSets) {
Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
for (Integer filterArg : requireNonNull(filterArgList, "filterArgList")) {
filters.put(Pair.of(groupSet, filterArg), z);
z += 1;
}
}
// GROUPING over all of fullGroupSet, appended last so it becomes the last
// output field of the lower aggregate ($g).
distinctAggCalls.add(
AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
ImmutableList.of(), ImmutableIntList.copyOf(fullGroupSet), -1,
null, RelCollations.EMPTY,
groupSets.size(), relBuilder.peek(), null, "$g"));
relBuilder.aggregate(
relBuilder.groupKey(fullGroupSet, groupSets),
distinctAggCalls);
// GROUPING is called over every column of fullGroupSet, so its integer
// value identifies which grouping set a row belongs to. It is compared
// below with groupValue(...), presumably the GROUPING value expected for
// that set. Replace $g with one BOOLEAN column per `filters` entry.
if (!filters.isEmpty()) {
final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
// Drop $g, which is the last field.
final RexNode nodeZ = nodes.remove(nodes.size() - 1);
for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : filters.entrySet()) {
final long v = groupValue(fullGroupSet.asList(), entry.getKey().left);
// Position of the filter column within the lower aggregate's output.
// Assumes remap(...) maps -1 (no filter) to a negative value, as the
// checks below expect; confirm against remap.
int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
if (distinctFilterArg > -1) {
// 'AND' the filter of the distinct aggregate call and the group value.
expr =
relBuilder.and(expr,
relBuilder.call(SqlStdOperatorTable.IS_TRUE,
relBuilder.field(distinctFilterArg)));
}
// Alias is "$g_<v>" or, with a filter, "$g_<v>_f_<arg>" ("f" = filter).
nodes.add(
relBuilder.alias(expr,
"$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + distinctFilterArg)));
}
relBuilder.project(nodes);
}
// Build the upper aggregate. `x` walks the non-distinct results, which
// follow the group columns in the lower aggregate's output.
int x = groupCount;
final ImmutableBitSet groupSet = aggregate.getGroupSet();
final List<AggregateCall> newCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
final List<Integer> newArgList;
final SqlAggFunction aggregation;
if (!aggCall.isDistinct()) {
// The value was already computed by the lower aggregate. MIN picks it
// from the rows of the aggregate's own grouping set, which the filter
// column for (groupSet, -1) selects.
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
newFilterArg =
requireNonNull(filters.get(Pair.of(groupSet, -1)),
"filters.get(Pair.of(groupSet, -1))");
} else {
// Rows of the distinct call's grouping set are already de-duplicated on
// its arguments, so the function is applied without DISTINCT.
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
// Must match the group-set expression used in the first loop.
final ImmutableBitSet newGroupSet = ImmutableBitSet.of(aggCall.getArgList())
.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
.union(groupSet);
newFilterArg =
requireNonNull(filters.get(Pair.of(newGroupSet, aggCall.filterArg)),
"filters.get(of(newGroupSet, aggCall.filterArg))");
}
final AggregateCall newCall =
AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
aggCall.isApproximate(), aggCall.ignoreNulls(),
aggCall.rexList, newArgList, newFilterArg,
aggCall.distinctKeys, aggCall.collation,
aggregate.getGroupCount(), relBuilder.peek(), null, aggCall.name);
newCalls.add(newCall);
}
// Group by the original group set (and group sets), translated into
// positions within the lower aggregate's group columns.
relBuilder.aggregate(
relBuilder.groupKey(
remap(fullGroupSet, groupSet),
remap(fullGroupSet, aggregate.getGroupSets())),
newCalls);
// Restore the original aggregate's row type (field types and names).
relBuilder.convert(aggregate.getRowType(), true);
call.transformTo(relBuilder.build());
}