private void rewriteUsingGroupingSets()

in flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateExpandDistinctAggregatesRule.java [443:620]


    /**
     * Rewrites an aggregate containing distinct aggregate calls into two aggregates connected
     * through GROUPING SETS.
     *
     * <p>The bottom aggregate groups the input by several grouping sets:
     *
     * <ul>
     *   <li>one per distinct call (the call's arguments, its filter column if any, and the
     *       original group keys);
     *   <li>the original group set, when there are non-distinct calls.
     * </ul>
     *
     * <p>Non-distinct calls are evaluated in the bottom aggregate. A {@code GROUPING} column
     * ({@code $g}) records which grouping set produced each row. A projection then replaces
     * {@code $g} with one BOOLEAN column per (grouping set, filter argument) pair, AND-ing in the
     * distinct call's own filter when it has one.
     *
     * <p>The top aggregate groups by the original keys and filters every call by its BOOLEAN
     * column. Distinct calls re-apply their aggregate function to the de-duplicated rows, and
     * non-distinct calls take their pre-computed value with {@code MIN}. When the original
     * aggregate has no group keys, results of non-distinct COUNT calls are wrapped so that a NULL
     * becomes 0.
     *
     * <p>Filter columns are keyed by the (grouping set, filter argument) pair rather than by the
     * grouping set alone, because calls can share a grouping set while having different filters.
     * For example:
     *
     * <ul>
     *   <li>{@code COUNT(DISTINCT a) FILTER (WHERE b)} next to
     *       {@code COUNT(DISTINCT b) FILTER (WHERE a)};
     *   <li>{@code COUNT(DISTINCT a, b)} next to {@code COUNT(DISTINCT a) FILTER (WHERE b)};
     *   <li>non-distinct calls next to a filtered distinct call whose grouping set equals the
     *       original group set.
     * </ul>
     *
     * <p>With a per-grouping-set key, one call's filter would be applied to the other calls, or
     * dropped from the call that owns it.
     *
     * @param call the rule call; the rewritten plan is registered via {@code call.transformTo}
     * @param aggregate the aggregate to rewrite
     */
    private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();

        // For every aggregate call (same order as aggCalls): the (grouping set, filter argument)
        // pair whose BOOLEAN column filters that call in the top aggregate. Non-distinct calls
        // run on the original group set, and their own filter is already applied in the bottom
        // aggregate, so they use filter argument -1 here.
        final List<Pair<ImmutableBitSet, Integer>> callFilterKeys =
                new ArrayList<>(aggCalls.size());
        final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
        // Grouping set -> filter arguments used with it (-1 = no filter of its own). A TreeSet
        // keeps the generated filter columns in a deterministic order.
        final Map<ImmutableBitSet, Set<Integer>> groupSetToFilterArgs = new HashMap<>();
        for (AggregateCall aggCall : aggCalls) {
            final ImmutableBitSet groupSet;
            final int filterArg;
            if (!aggCall.isDistinct()) {
                groupSet = aggregate.getGroupSet();
                filterArg = -1;
            } else {
                filterArg = aggCall.filterArg;
                groupSet =
                        ImmutableBitSet.of(aggCall.getArgList())
                                .setIf(filterArg, filterArg >= 0)
                                .union(aggregate.getGroupSet());
            }
            groupSetTreeSet.add(groupSet);
            groupSetToFilterArgs.computeIfAbsent(groupSet, k -> new TreeSet<>()).add(filterArg);
            callFilterKeys.add(Pair.of(groupSet, filterArg));
        }

        final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets =
                com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
        final int groupCount = fullGroupSet.cardinality();

        // Calls evaluated by the bottom aggregate: every non-distinct call, keeping its name.
        final List<AggregateCall> bottomAggCalls = new ArrayList<>();
        for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
            if (!aggCall.left.isDistinct()) {
                AggregateCall newAggCall =
                        aggCall.left.adaptTo(
                                aggregate.getInput(),
                                aggCall.left.getArgList(),
                                aggCall.left.filterArg,
                                aggregate.getGroupCount(),
                                groupCount);
                bottomAggCalls.add(newAggCall.withName(aggCall.right));
            }
        }

        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());

        // $g: GROUPING over all keys; it is the last output column of the bottom aggregate
        // (index z) and identifies the grouping set that produced each row.
        final int z = groupCount + bottomAggCalls.size();
        bottomAggCalls.add(
                AggregateCall.create(
                        SqlStdOperatorTable.GROUPING,
                        false,
                        false,
                        false,
                        ImmutableIntList.copyOf(fullGroupSet),
                        -1,
                        null,
                        RelCollations.EMPTY,
                        groupSets.size(),
                        relBuilder.peek(),
                        null,
                        "$g"));

        // (grouping set, filter argument) -> index of its BOOLEAN filter column. The projection
        // below replaces $g with these columns in exactly this (insertion) order.
        final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
        for (ImmutableBitSet groupSet : groupSets) {
            for (Integer filterArg : groupSetToFilterArgs.get(groupSet)) {
                filters.put(Pair.of(groupSet, filterArg), z + filters.size());
            }
        }

        relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), bottomAggCalls);
        final RelNode distinct = relBuilder.peek();

        // GROUPING returns an integer. Replace it with one BOOLEAN column per
        // (grouping set, filter argument) pair: true when the row comes from that grouping set
        // and, for a filtered distinct call, its filter column is also true.
        if (!filters.isEmpty()) {
            final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
            final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
            final RexNode nodeZ = nodes.remove(nodes.size() - 1);
            for (Pair<ImmutableBitSet, Integer> key : filters.keySet()) {
                final long v = groupValue(fullGroupSet, key.left);
                // The filter column is one of the grouping keys; remap it to its position in
                // the bottom aggregate's output (negative when there is no filter).
                final int filterField = remap(fullGroupSet, key.right);
                RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
                String alias = "$g_" + v;
                if (filterField >= 0) {
                    // Merge the distinct aggregate call's own filter.
                    expr =
                            relBuilder.and(
                                    expr,
                                    rexBuilder.makeCall(
                                            SqlStdOperatorTable.IS_TRUE,
                                            relBuilder.field(filterField)));
                    if (groupSetToFilterArgs.get(key.left).size() > 1) {
                        // Several filters share this grouping set: keep the aliases unique.
                        alias += "_f_" + filterField;
                    }
                }
                nodes.add(relBuilder.alias(expr, alias));
            }
            relBuilder.project(nodes);
        }

        int x = groupCount;
        final List<AggregateCall> newCalls = new ArrayList<>(aggCalls.size());
        // TODO supports more aggCalls (currently only supports COUNT)
        // Some aggregate functions (e.g. COUNT) have the special property that they can return a
        // non-null result without any input. We need to make sure we return a result in this case.
        final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
        for (Ord<AggregateCall> ord : Ord.zip(aggCalls)) {
            final AggregateCall aggCall = ord.e;
            final SqlAggFunction aggregation;
            final List<Integer> newArgList;
            if (!aggCall.isDistinct()) {
                // The value was already computed by the bottom aggregate; pick it up.
                aggregation = SqlStdOperatorTable.MIN;
                newArgList = ImmutableIntList.of(x++);
                if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
                    needDefaultValueAggCalls.add(ord.i);
                }
            } else {
                aggregation = aggCall.getAggregation();
                newArgList = remap(fullGroupSet, aggCall.getArgList());
            }
            // Every key in callFilterKeys was registered in filters above.
            final int newFilterArg = filters.get(callFilterKeys.get(ord.i));
            final AggregateCall newCall =
                    AggregateCall.create(
                            aggregation,
                            false,
                            aggCall.isApproximate(),
                            false,
                            newArgList,
                            newFilterArg,
                            aggCall.distinctKeys,
                            RelCollations.EMPTY,
                            aggregate.getGroupCount(),
                            distinct,
                            null,
                            aggCall.name);
            newCalls.add(newCall);
        }

        relBuilder.aggregate(
                relBuilder.groupKey(
                        remap(fullGroupSet, aggregate.getGroupSet()),
                        remap(fullGroupSet, aggregate.getGroupSets())),
                newCalls);
        if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
            // Replace a NULL COUNT (no matching input) with 0.
            final Aggregate newAgg = (Aggregate) relBuilder.peek();
            final List<RexNode> nodes = new ArrayList<>();
            for (int i = 0; i < newAgg.getGroupCount(); ++i) {
                nodes.add(RexInputRef.of(i, newAgg.getRowType()));
            }
            for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
                final RexNode inputRef =
                        RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
                RexNode newNode = inputRef;
                if (needDefaultValueAggCalls.contains(i)
                        && aggCalls.get(i).getAggregation().getKind() == SqlKind.COUNT) {
                    newNode =
                            relBuilder.call(
                                    SqlStdOperatorTable.CASE,
                                    relBuilder.isNotNull(inputRef),
                                    inputRef,
                                    relBuilder.literal(BigDecimal.ZERO));
                }
                nodes.add(newNode);
            }
            relBuilder.project(nodes);
        }

        relBuilder.convert(aggregate.getRowType(), true);
        call.transformTo(relBuilder.build());
    }