public List buildRules()

in fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java [109:443]


    /**
     * Builds the rules contributed by this class. Each entry binds a {@link RuleType} to a plan
     * pattern plus the transformation applied when the pattern matches.
     *
     * <p>The list first declares the rules that try to push an aggregate towards the scan
     * (count on inverted index, min/max on unique-key tables, storage-layer aggregation for
     * olap and file scans), followed by the rules that split a logical aggregate into one,
     * two, three or four phases.
     *
     * @return an immutable list of rules, in declaration order
     */
    public List<Rule> buildRules() {
        // Shared starting pattern for all of the phase-splitting rules below.
        PatternDescriptor<LogicalAggregate<GroupPlan>> basePattern = logicalAggregate();

        return ImmutableList.of(
            // aggregate -> filter -> olap scan: count pushdown onto the index.
            RuleType.COUNT_ON_INDEX_WITHOUT_PROJECT.build(
                logicalAggregate(
                    logicalFilter(
                        logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
                    )
                )
                .when(agg -> enablePushDownCountOnIndex())
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .when(agg -> {
                    // Only non-distinct count(*) / count(slot) qualify.
                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                    if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
                            && (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
                        return false;
                    }
                    // The filter must carry at least one predicate.
                    Set<Expression> conjuncts = agg.child().getConjuncts();
                    if (conjuncts.isEmpty()) {
                        return false;
                    }

                    // When the counts read slots, every conjunct has to pass both slot checks.
                    Set<Slot> aggSlots = funcs.stream()
                            .flatMap(f -> f.getInputSlots().stream())
                            .collect(Collectors.toSet());
                    return aggSlots.isEmpty() || conjuncts.stream().allMatch(expr ->
                                checkSlotInOrExpression(expr, aggSlots) && checkIsNullExpr(expr, aggSlots));
                })
                .thenApply(ctx -> {
                    LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                    LogicalFilter<LogicalOlapScan> filter = agg.child();
                    LogicalOlapScan olapScan = filter.child();
                    return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext);
                })
            ),
            // Same as above, with a project between the aggregate and the filter.
            RuleType.COUNT_ON_INDEX.build(
                logicalAggregate(
                    logicalProject(
                        logicalFilter(
                            logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
                        )
                    )
                )
                .when(agg -> enablePushDownCountOnIndex())
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .when(agg -> {
                    // Only non-distinct count(*) / count(slot) qualify.
                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                    if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
                            && (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
                        return false;
                    }
                    // The filter (below the project) must carry at least one predicate.
                    Set<Expression> conjuncts = agg.child().child().getConjuncts();
                    if (conjuncts.isEmpty()) {
                        return false;
                    }

                    // When the counts read slots, every conjunct has to pass both slot checks.
                    Set<Slot> aggSlots = funcs.stream()
                            .flatMap(f -> f.getInputSlots().stream())
                            .collect(Collectors.toSet());
                    return aggSlots.isEmpty() || conjuncts.stream().allMatch(expr ->
                                checkSlotInOrExpression(expr, aggSlots) && checkIsNullExpr(expr, aggSlots));
                })
                .thenApply(ctx -> {
                    LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                    LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                    LogicalFilter<LogicalOlapScan> filter = project.child();
                    LogicalOlapScan olapScan = filter.child();
                    return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
                })
            ),
            // aggregate -> filter -> unique-key olap scan: min/max pushdown, allowed only when
            // the filter is the single delete-sign conjunct.
            RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
                logicalAggregate(
                    logicalFilter(
                        logicalOlapScan().when(this::isUniqueKeyTable)
                    ).when(this::isSingleDeleteSignConjunct)
                )
                .when(agg -> enablePushDownMinMaxOnUnique())
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .when(agg -> {
                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                    return !funcs.isEmpty() && funcs.stream()
                            .allMatch(f -> (f instanceof Min) || (f instanceof Max));
                })
                .thenApply(ctx -> {
                    LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                    LogicalFilter<LogicalOlapScan> filter = agg.child();
                    LogicalOlapScan olapScan = filter.child();
                    return pushdownMinMaxOnUniqueTable(agg, null, filter, olapScan,
                            ctx.cascadesContext);
                })
            ),
            // Same as above, with a project between the aggregate and the filter.
            RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE.build(
                logicalAggregate(
                    logicalProject(
                        logicalFilter(
                            logicalOlapScan().when(this::isUniqueKeyTable)
                        ).when(this::isSingleDeleteSignConjunct)
                    )
                )
                .when(agg -> enablePushDownMinMaxOnUnique())
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .when(agg -> {
                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                    return !funcs.isEmpty()
                            && funcs.stream().allMatch(f -> (f instanceof Min) || (f instanceof Max));
                })
                .thenApply(ctx -> {
                    LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                    LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                    LogicalFilter<LogicalOlapScan> filter = project.child();
                    LogicalOlapScan olapScan = filter.child();
                    return pushdownMinMaxOnUniqueTable(agg, project, filter, olapScan,
                            ctx.cascadesContext);
                })
            ),
            // Normalized aggregate directly on an olap scan.
            RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
                logicalAggregate(
                    logicalOlapScan()
                )
                .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
                .thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
            ),
            // Project directly on an olap scan; no extra predicate is applied at pattern level,
            // the decision is left to pushDownCountWithoutSlotRef.
            RuleType.STORAGE_LAYER_WITH_PROJECT_NO_SLOT_REF.build(
                    logicalProject(
                            logicalOlapScan()
                    )
                    .thenApply(ctx -> {
                        LogicalProject<LogicalOlapScan> project = ctx.root;
                        LogicalOlapScan olapScan = project.child();
                        return pushDownCountWithoutSlotRef(project, olapScan, ctx.cascadesContext);
                    })
            ),
            // Normalized aggregate on a project on an olap scan.
            RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT.build(
                logicalAggregate(
                    logicalProject(
                        logicalOlapScan()
                    )
                )
                .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
                .thenApply(ctx -> {
                    LogicalAggregate<LogicalProject<LogicalOlapScan>> agg = ctx.root;
                    LogicalProject<LogicalOlapScan> project = agg.child();
                    LogicalOlapScan olapScan = project.child();
                    return storageLayerAggregate(agg, project, olapScan, ctx.cascadesContext);
                })
            ),
            // File-scan counterparts of the two storage-layer aggregate rules above.
            RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT_FOR_FILE_SCAN.build(
                logicalAggregate(
                    logicalFileScan()
                )
                    .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
                    .thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
            ),
            RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT_FOR_FILE_SCAN.build(
                logicalAggregate(
                    logicalProject(
                        logicalFileScan()
                    )
                ).when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
                    .thenApply(ctx -> {
                        LogicalAggregate<LogicalProject<LogicalFileScan>> agg = ctx.root;
                        LogicalProject<LogicalFileScan> project = agg.child();
                        LogicalFileScan fileScan = project.child();
                        return storageLayerAggregate(agg, project, fileScan, ctx.cascadesContext);
                    })
            ),
            // Aggregates without any distinct argument: one-phase and two-phase alternatives.
            RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().isEmpty())
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
                    .thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
            ),
            RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().isEmpty())
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                    .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
            ),
            // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
            //     basePattern
            //         .when(this::containsCountDistinctMultiExpr)
            //         .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
            //         .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
            // ),
            RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
                basePattern
                    .when(this::containsCountDistinctMultiExpr)
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
                    .thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
            ),
            // Exactly one distinct argument that can be converted to a multi-distinct function.
            RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
                    .thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
            ),
            RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                    .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
            ),
            // More than one distinct argument, none handled by the count-distinct-multi rule.
            RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().size() > 1
                            && !containsCountDistinctMultiExpr(agg)
                            && couldConvertToMulti(agg))
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                    .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
            ),
            // RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
            //     basePattern
            //         .when(agg -> agg.getDistinctArguments().size() == 1)
            //         .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
            //         .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
            // ),
            RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().size() == 1)
                    .whenNot(agg -> agg.mustUseMultiDistinctAgg())
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
                    .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
            ),
            /*
             * sql:
             * select count(distinct name), sum(age) from student;
             * <p>
             * 4 phase plan
             * DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
             *          output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
             *          GATHER)
             * +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
             *          output(partial_count(name), partial_sum(partial_sum(age))),
             *          hash distribute by name)
             *    +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
             *          output(name, partial_sum(age)),
             *          hash_distribute by name)
             *       +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
             *          +--scan(name, age)
             */
            RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
                basePattern
                    .when(agg -> agg.getDistinctArguments().size() == 1)
                    .when(agg -> agg.getGroupByExpressions().isEmpty())
                    .whenNot(agg -> agg.mustUseMultiDistinctAgg())
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
                    .thenApplyMulti(ctx -> {
                        // The lambda argument is intentionally ignored: with no group-by keys the
                        // hash requirement is built from the distinct arguments of the root aggregate.
                        Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
                                groupByAndDistinct -> RequireProperties.of(
                                        PhysicalProperties.createHash(
                                                ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
                                        )
                                );
                        Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
                                agg -> RequireProperties.of(PhysicalProperties.GATHER);
                        return fourPhaseAggregateWithDistinct(
                                ctx.root, ctx.connectContext,
                                secondPhaseRequireDistinctHash, fourPhaseRequireGather
                        );
                    })
            ),
            /*
             * sql:
             * select age, count(distinct name) from student group by age;
             * <p>
             * 4 phase plan
             * DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
             *          output[age, sum(partial_count(name))],
             *          hash distribute by name)
             * +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
             *          output(age, partial_count(name)),
             *          hash distribute by age, name)
             *    +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
             *          output(age, name),
             *          hash_distribute by age, name)
             *       +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
             *          +--scan(age, name)
             *
             * NOTE(review): the rule below requests a hash on the group-by expressions for the
             * last phase, so "hash distribute by name" on DISTINCT_GLOBAL presumably should
             * read "age" - confirm against fourPhaseAggregateWithDistinct.
             */
            RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
                basePattern
                    .when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
                    // At least one distinct argument must not already be a group-by expression.
                    .when(agg ->
                        ImmutableSet.builder()
                            .addAll(agg.getGroupByExpressions())
                            .addAll(agg.getDistinctArguments())
                            .build().size() > agg.getGroupByExpressions().size()
                    )
                    .when(agg -> {
                        if (agg.getDistinctArguments().size() == 1) {
                            return true;
                        }
                        return couldConvertToMulti(agg);
                    })
                    .when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
                    .thenApplyMulti(ctx -> {
                        Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
                                groupByAndDistinct -> RequireProperties.of(
                                        PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
                                );

                        Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
                                agg -> RequireProperties.of(
                                        PhysicalProperties.createHash(
                                                agg.getGroupByExpressions(), ShuffleType.REQUIRE
                                        )
                                );
                        return fourPhaseAggregateWithDistinct(
                                ctx.root, ctx.connectContext,
                                secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
                        );
                    })
            )
        );
    }

    /**
     * Tells whether the filter holds exactly one conjunct whose first child is a slot bound to
     * the delete-sign column.
     *
     * @param filter the filter sitting on top of the unique-key scan
     * @return {@code true} only for a single conjunct whose first child is the delete-sign slot
     */
    private boolean isSingleDeleteSignConjunct(LogicalFilter<? extends Plan> filter) {
        Set<Expression> conjuncts = filter.getConjuncts();
        if (conjuncts.size() != 1) {
            return false;
        }
        Expression conjunct = conjuncts.iterator().next();
        // A childless conjunct (e.g. a bare boolean slot or a literal) cannot reference the
        // delete-sign column through its first child; reject it instead of letting get(0) throw.
        if (conjunct.children().isEmpty()) {
            return false;
        }
        Expression childExpr = conjunct.children().get(0);
        if (childExpr instanceof SlotReference) {
            Optional<Column> column = ((SlotReference) childExpr).getColumn();
            return column.map(Column::isDeleteSignColumn).orElse(false);
        }
        return false;
    }