private static void rewriteUsingGroupingSets()

in core/src/main/java/org/apache/calcite/rel/rules/AggregateExpandDistinctAggregatesRule.java [428:561]


  /**
   * Rewrites an aggregate that has DISTINCT aggregate calls into two
   * aggregates: a bottom aggregate that uses GROUPING SETS, and a top
   * aggregate that combines the rows produced for those grouping sets.
   *
   * <p>All non-DISTINCT calls share the original group set. They are
   * evaluated in the bottom aggregate and carried to the top with
   * {@code MIN}.
   *
   * <p>Each DISTINCT call gets the grouping set formed by its arguments, its
   * filter column (if any) and the original group keys. It is evaluated
   * (non-distinct) in the top aggregate over the rows of that grouping set
   * only.
   *
   * <p>A {@code GROUPING} column ("$g") identifies the grouping set of each
   * bottom row. It is replaced by one BOOLEAN filter column per
   * (grouping set, filter argument) pair.
   *
   * <p>The rewritten expression, converted to the original row type, is
   * registered via {@link RelOptRuleCall#transformTo}.
   *
   * @param call      Rule call; supplies the {@link RelBuilder} and receives
   *                  the rewritten expression
   * @param aggregate Aggregate to rewrite
   */
  private static void rewriteUsingGroupingSets(RelOptRuleCall call,
      Aggregate aggregate) {
    // Grouping sets of the bottom aggregate, kept in a deterministic order.
    final Set<ImmutableBitSet> groupSetTreeSet =
        new TreeSet<>(ImmutableBitSet.ORDERING);
    // GroupSet to distinct filter arg map.
    // filterArg is -1 for a non-distinct agg call (its own filter, if any, is
    // passed to the bottom aggregate) and for a distinct call with no filter.

    // Using `Set` here because it's possible that two agg calls
    // have different filterArgs but same groupSet.
    final Map<ImmutableBitSet, Set<Integer>> distinctFilterArgMap = new HashMap<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      final ImmutableBitSet groupSet;
      final int filterArg;
      if (!aggCall.isDistinct()) {
        filterArg = -1;
        groupSet = aggregate.getGroupSet();
      } else {
        filterArg = aggCall.filterArg;
        groupSet = groupSetForDistinctCall(aggCall, aggregate.getGroupSet());
      }
      groupSetTreeSet.add(groupSet);
      distinctFilterArgMap
          .computeIfAbsent(groupSet, g -> new HashSet<>())
          .add(filterArg);
    }

    final ImmutableList<ImmutableBitSet> groupSets =
        ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

    // Calls of the bottom aggregate: first every non-distinct call, adapted to
    // the wider group key (fullGroupSet); the GROUPING call is appended later.
    final List<AggregateCall> bottomAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
      if (!aggCall.left.isDistinct()) {
        AggregateCall newAggCall =
            aggCall.left.adaptTo(aggregate.getInput(),
                aggCall.left.getArgList(), aggCall.left.filterArg,
                aggregate.getGroupCount(), fullGroupSet.cardinality());
        bottomAggCalls.add(newAggCall.withName(aggCall.right));
      }
    }

    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    final int groupCount = fullGroupSet.cardinality();

    // Assign each (groupSet, filterArg) pair the ordinal of the BOOLEAN column
    // that the project below creates for it. Those columns follow the group
    // keys and the non-distinct calls. The "$g" column is dropped by that
    // project, so it does not take an ordinal; hence z is computed before
    // GROUPING is added to bottomAggCalls.
    final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
    int z = groupCount + bottomAggCalls.size();
    for (ImmutableBitSet groupSet : groupSets) {
      Set<Integer> filterArgList = distinctFilterArgMap.get(groupSet);
      for (Integer filterArg : requireNonNull(filterArgList, "filterArgList")) {
        filters.put(Pair.of(groupSet, filterArg), z);
        z += 1;
      }
    }

    // NOTE(review): groupSets.size() is passed as the group count, although
    // the aggregate built below has groupCount keys. This only matters if
    // GROUPING's return type depends on it; confirm.
    bottomAggCalls.add(
        AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false,
            ImmutableList.of(), ImmutableIntList.copyOf(fullGroupSet), -1,
            null, RelCollations.EMPTY,
            groupSets.size(), relBuilder.peek(), null, "$g"));

    relBuilder.aggregate(
        relBuilder.groupKey(fullGroupSet, groupSets),
        bottomAggCalls);

    // GROUPING is called with every field of fullGroupSet, so it returns a bit
    // mask (one bit per field), not just 0 or 1. Replace the "$g" column with
    // one BOOLEAN column per (groupSet, filterArg). Each column is true when
    // the mask equals groupValue(...) for that groupSet and, if there is a
    // filter arg, the filter column IS TRUE.
    if (!filters.isEmpty()) {
      final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
      final RexNode nodeZ = nodes.remove(nodes.size() - 1);
      for (Map.Entry<Pair<ImmutableBitSet, Integer>, Integer> entry : filters.entrySet()) {
        final long v = groupValue(fullGroupSet.asList(), entry.getKey().left);
        int distinctFilterArg = remap(fullGroupSet, entry.getKey().right);
        RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
        if (distinctFilterArg > -1) {
          // 'AND' the filter of the distinct aggregate call and the group value.
          expr =
              relBuilder.and(expr,
                  relBuilder.call(SqlStdOperatorTable.IS_TRUE,
                      relBuilder.field(distinctFilterArg)));
        }
        // "f" means filter.
        nodes.add(
            relBuilder.alias(expr,
            "$g_" + v + (distinctFilterArg < 0 ? "" : "_f_" + distinctFilterArg)));
      }
      relBuilder.project(nodes);
    }

    // Build the top aggregate. Non-distinct results sit right after the group
    // keys, in the same order as in aggregate.getAggCallList().
    int x = groupCount;
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      final int newFilterArg;
      final List<Integer> newArgList;
      final SqlAggFunction aggregation;
      if (!aggCall.isDistinct()) {
        // Already computed by the bottom aggregate; pick its value from the
        // row of the original group set.
        aggregation = SqlStdOperatorTable.MIN;
        newArgList = ImmutableIntList.of(x++);
        newFilterArg =
            requireNonNull(filters.get(Pair.of(groupSet, -1)),
                "filters.get(Pair.of(groupSet, -1))");
      } else {
        aggregation = aggCall.getAggregation();
        newArgList = remap(fullGroupSet, aggCall.getArgList());
        final ImmutableBitSet newGroupSet =
            groupSetForDistinctCall(aggCall, groupSet);
        newFilterArg =
            requireNonNull(filters.get(Pair.of(newGroupSet, aggCall.filterArg)),
                "filters.get(Pair.of(newGroupSet, aggCall.filterArg))");
      }
      // NOTE(review): rexList, distinctKeys and collation are copied unchanged,
      // although this call reads a different input than the original call;
      // confirm they are empty or still valid whenever this rule fires.
      final AggregateCall newCall =
          AggregateCall.create(aggCall.getParserPosition(), aggregation, false,
              aggCall.isApproximate(), aggCall.ignoreNulls(),
              aggCall.rexList, newArgList, newFilterArg,
              aggCall.distinctKeys, aggCall.collation,
              aggregate.getGroupCount(), relBuilder.peek(), null, aggCall.name);
      newCalls.add(newCall);
    }

    relBuilder.aggregate(
        relBuilder.groupKey(
            remap(fullGroupSet, groupSet),
            remap(fullGroupSet, aggregate.getGroupSets())),
        newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
  }

  /**
   * Returns the grouping set of the bottom aggregate in which a DISTINCT
   * aggregate call is evaluated. The set is the call's arguments, plus its
   * filter column if it has one, plus the original group keys.
   *
   * <p>Shared by the code that builds the grouping sets and the code that
   * looks up each call's filter column, so that the two cannot diverge.
   *
   * @param aggCall  DISTINCT aggregate call
   * @param groupSet Group set of the original aggregate
   * @return Grouping set for {@code aggCall}
   */
  private static ImmutableBitSet groupSetForDistinctCall(AggregateCall aggCall,
      ImmutableBitSet groupSet) {
    return ImmutableBitSet.of(aggCall.getArgList())
        .setIf(aggCall.filterArg, aggCall.filterArg >= 0)
        .union(groupSet);
  }