in fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java [109:443]
public List<Rule> buildRules() {
    // Builds the immutable list of implementation rules for LogicalAggregate, in this order:
    //   1. COUNT_ON_INDEX rules            (aggregate [-> project] -> filter -> olap scan)
    //   2. MINMAX_ON_UNIQUE rules          (aggregate [-> project] -> filter -> olap scan)
    //   3. STORAGE_LAYER_* rules           (aggregate/project directly over an olap or file scan)
    //   4. one/two/three/four phase rules  (all rooted at the shared basePattern below)

    // Root pattern shared by every multi-phase rule: a logical aggregate over a group plan.
    PatternDescriptor<LogicalAggregate<GroupPlan>> basePattern = logicalAggregate();
    return ImmutableList.of(
        // aggregate -> filter -> olap scan (no project in between)
        RuleType.COUNT_ON_INDEX_WITHOUT_PROJECT.build(
            logicalAggregate(
                logicalFilter(
                    logicalOlapScan()
                        .when(this::isDupOrMowKeyTable)
                        .when(this::isInvertedIndexEnabledOnTable)
                )
            )
            .when(agg -> enablePushDownCountOnIndex())
            // only aggregates without GROUP BY are considered
            .when(agg -> agg.getGroupByExpressions().isEmpty())
            .when(agg -> countAggEligibleForIndexPushDown(
                    agg.getAggregateFunctions(), agg.child().getConjuncts()))
            .thenApply(ctx -> {
                LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                LogicalFilter<LogicalOlapScan> filter = agg.child();
                LogicalOlapScan olapScan = filter.child();
                // null: there is no project node in this plan shape
                return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext);
            })
        ),
        // aggregate -> project -> filter -> olap scan
        RuleType.COUNT_ON_INDEX.build(
            logicalAggregate(
                logicalProject(
                    logicalFilter(
                        logicalOlapScan()
                            .when(this::isDupOrMowKeyTable)
                            .when(this::isInvertedIndexEnabledOnTable)
                    )
                )
            )
            .when(agg -> enablePushDownCountOnIndex())
            .when(agg -> agg.getGroupByExpressions().isEmpty())
            .when(agg -> countAggEligibleForIndexPushDown(
                    agg.getAggregateFunctions(), agg.child().child().getConjuncts()))
            .thenApply(ctx -> {
                LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                LogicalFilter<LogicalOlapScan> filter = project.child();
                LogicalOlapScan olapScan = filter.child();
                return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext);
            })
        ),
        // aggregate -> filter(delete sign only) -> unique-key olap scan
        RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
            logicalAggregate(
                logicalFilter(
                    logicalOlapScan().when(this::isUniqueKeyTable))
                    .when(this::onlyFiltersOnDeleteSign)
            )
            .when(agg -> enablePushDownMinMaxOnUnique())
            .when(agg -> agg.getGroupByExpressions().isEmpty())
            // at least one aggregate function, and every one of them is min() or max()
            .when(agg -> {
                Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                return !funcs.isEmpty() && funcs.stream()
                        .allMatch(f -> (f instanceof Min) || (f instanceof Max));
            })
            .thenApply(ctx -> {
                LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
                LogicalFilter<LogicalOlapScan> filter = agg.child();
                LogicalOlapScan olapScan = filter.child();
                // null: there is no project node in this plan shape
                return pushdownMinMaxOnUniqueTable(agg, null, filter, olapScan,
                        ctx.cascadesContext);
            })
        ),
        // aggregate -> project -> filter(delete sign only) -> unique-key olap scan
        RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE.build(
            logicalAggregate(
                logicalProject(
                    logicalFilter(
                        logicalOlapScan().when(this::isUniqueKeyTable))
                        .when(this::onlyFiltersOnDeleteSign))
            )
            .when(agg -> enablePushDownMinMaxOnUnique())
            .when(agg -> agg.getGroupByExpressions().isEmpty())
            .when(agg -> {
                Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                return !funcs.isEmpty()
                        && funcs.stream().allMatch(f -> (f instanceof Min) || (f instanceof Max));
            })
            .thenApply(ctx -> {
                LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
                LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
                LogicalFilter<LogicalOlapScan> filter = project.child();
                LogicalOlapScan olapScan = filter.child();
                return pushdownMinMaxOnUniqueTable(agg, project, filter, olapScan,
                        ctx.cascadesContext);
            })
        ),
        // normalized aggregate directly over an olap scan
        RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
            logicalAggregate(
                logicalOlapScan()
            )
            .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
            .thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
        ),
        // project directly over an olap scan; note this rule has no aggregate in its pattern
        // and no session-variable guard at this level
        RuleType.STORAGE_LAYER_WITH_PROJECT_NO_SLOT_REF.build(
            logicalProject(
                logicalOlapScan()
            )
            .thenApply(ctx -> {
                LogicalProject<LogicalOlapScan> project = ctx.root;
                LogicalOlapScan olapScan = project.child();
                return pushDownCountWithoutSlotRef(project, olapScan, ctx.cascadesContext);
            })
        ),
        // normalized aggregate -> project -> olap scan
        RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT.build(
            logicalAggregate(
                logicalProject(
                    logicalOlapScan()
                )
            )
            .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
            .thenApply(ctx -> {
                LogicalAggregate<LogicalProject<LogicalOlapScan>> agg = ctx.root;
                LogicalProject<LogicalOlapScan> project = agg.child();
                LogicalOlapScan olapScan = project.child();
                return storageLayerAggregate(agg, project, olapScan, ctx.cascadesContext);
            })
        ),
        // normalized aggregate directly over a file scan
        RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT_FOR_FILE_SCAN.build(
            logicalAggregate(
                logicalFileScan()
            )
            .when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
            .thenApply(ctx -> storageLayerAggregate(ctx.root, null, ctx.root.child(), ctx.cascadesContext))
        ),
        // normalized aggregate -> project -> file scan
        RuleType.STORAGE_LAYER_AGGREGATE_WITH_PROJECT_FOR_FILE_SCAN.build(
            logicalAggregate(
                logicalProject(
                    logicalFileScan()
                )
            ).when(agg -> agg.isNormalized() && enablePushDownNoGroupAgg())
            .thenApply(ctx -> {
                LogicalAggregate<LogicalProject<LogicalFileScan>> agg = ctx.root;
                LogicalProject<LogicalFileScan> project = agg.child();
                LogicalFileScan fileScan = project.child();
                return storageLayerAggregate(agg, project, fileScan, ctx.cascadesContext);
            })
        ),
        // aggregates without any distinct argument: one-phase and two-phase alternatives
        RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().isEmpty())
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
                .thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
        ),
        RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().isEmpty())
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
        ),
        // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
        //     basePattern
        //         .when(this::containsCountDistinctMultiExpr)
        //         .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
        //         .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
        // ),
        RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
            basePattern
                .when(this::containsCountDistinctMultiExpr)
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
                .thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
        ),
        // exactly one distinct argument that passes couldConvertToMulti: one-phase and two-phase
        RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
                .thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
        ),
        RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
        ),
        // more than one distinct argument, excluding the count-distinct-multi case handled above
        RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().size() > 1
                        && !containsCountDistinctMultiExpr(agg)
                        && couldConvertToMulti(agg))
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
                .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
        ),
        // RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
        //     basePattern
        //         .when(agg -> agg.getDistinctArguments().size() == 1)
        //         .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
        //         .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
        // ),
        RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().size() == 1)
                .whenNot(agg -> agg.mustUseMultiDistinctAgg())
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
                .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
        ),
        /*
         * sql:
         * select count(distinct name), sum(age) from student;
         * <p>
         * 4 phase plan
         * DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
         *          output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
         *          GATHER)
         * +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
         *          output(partial_count(name), partial_sum(partial_sum(age))),
         *          hash distribute by name)
         *    +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
         *          output(name, partial_sum(age)),
         *          hash_distribute by name)
         *       +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
         *          +--scan(name, age)
         */
        RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
            basePattern
                .when(agg -> agg.getDistinctArguments().size() == 1)
                .when(agg -> agg.getGroupByExpressions().isEmpty())
                .whenNot(agg -> agg.mustUseMultiDistinctAgg())
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
                .thenApplyMulti(ctx -> {
                    // The lambda argument is ignored: the hash is built from the root's distinct
                    // arguments (group by is required to be empty by the guard above).
                    Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
                            groupByAndDistinct -> RequireProperties.of(
                                    PhysicalProperties.createHash(
                                            ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
                                    )
                            );
                    Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
                            agg -> RequireProperties.of(PhysicalProperties.GATHER);
                    return fourPhaseAggregateWithDistinct(
                            ctx.root, ctx.connectContext,
                            secondPhaseRequireDistinctHash, fourPhaseRequireGather
                    );
                })
        ),
        /*
         * sql:
         * select age, count(distinct name) from student group by age;
         * <p>
         * 4 phase plan
         * DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
         *          output[age, sum(partial_count(name))],
         *          hash distribute by age)
         * +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
         *          output(age, partial_count(name)),
         *          hash distribute by age, name)
         *    +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
         *          output(age, name),
         *          hash_distribute by age, name)
         *       +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
         *          +--scan(age, name)
         *
         * NOTE(review): the top-level distribution is written as "by age" because the
         * fourth-phase requirement built below hashes on the group-by expressions;
         * confirm against fourPhaseAggregateWithDistinct.
         */
        RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
            basePattern
                .when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
                // the de-duplicated union of group-by keys and distinct arguments must be
                // larger than the group-by list
                .when(agg ->
                    ImmutableSet.builder()
                        .addAll(agg.getGroupByExpressions())
                        .addAll(agg.getDistinctArguments())
                        .build().size() > agg.getGroupByExpressions().size()
                )
                // a single distinct argument is always accepted; several must be convertible
                .when(agg -> {
                    if (agg.getDistinctArguments().size() == 1) {
                        return true;
                    }
                    return couldConvertToMulti(agg);
                })
                .when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
                .thenApplyMulti(ctx -> {
                    Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
                            groupByAndDistinct -> RequireProperties.of(
                                    PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
                            );
                    Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
                            agg -> RequireProperties.of(
                                    PhysicalProperties.createHash(
                                            agg.getGroupByExpressions(), ShuffleType.REQUIRE
                                    )
                            );
                    return fourPhaseAggregateWithDistinct(
                            ctx.root, ctx.connectContext,
                            secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
                    );
                })
        )
    );
}

/**
 * Shared guard of the two COUNT_ON_INDEX rules.
 *
 * <p>Accepts only when all of the following hold:
 * <ul>
 *   <li>there is at least one aggregate function and every one is a non-distinct
 *       {@code Count} that is either {@code count(*)} or counts a plain {@code Slot};</li>
 *   <li>the filter has at least one conjunct;</li>
 *   <li>either the aggregate functions read no input slots, or every conjunct passes both
 *       {@code checkSlotInOrExpression} and {@code checkIsNullExpr} for those slots.</li>
 * </ul>
 *
 * @param funcs aggregate functions of the matched aggregate
 * @param conjuncts conjuncts of the filter below the aggregate
 * @return true if the count aggregate qualifies for the push-down
 */
private boolean countAggEligibleForIndexPushDown(Set<AggregateFunction> funcs, Set<Expression> conjuncts) {
    if (funcs.isEmpty() || !funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()
            && (((Count) f).isCountStar() || f.child(0) instanceof Slot))) {
        return false;
    }
    if (conjuncts.isEmpty()) {
        return false;
    }
    Set<Slot> aggSlots = funcs.stream()
            .flatMap(f -> f.getInputSlots().stream())
            .collect(Collectors.toSet());
    return aggSlots.isEmpty() || conjuncts.stream().allMatch(expr ->
            checkSlotInOrExpression(expr, aggSlots) && checkIsNullExpr(expr, aggSlots));
}

/**
 * Shared guard of the two MINMAX_ON_UNIQUE rules.
 *
 * <p>Accepts only a filter with exactly one conjunct whose first child is a
 * {@code SlotReference} bound to a column for which {@code Column.isDeleteSignColumn()}
 * is true. A conjunct without children (a leaf expression) is rejected instead of
 * failing with an {@code IndexOutOfBoundsException}.
 *
 * @param filter filter node sitting directly above the olap scan
 * @return true if the filter's only conjunct starts with the delete-sign column
 */
private boolean onlyFiltersOnDeleteSign(LogicalFilter<? extends Plan> filter) {
    Set<Expression> conjuncts = filter.getConjuncts();
    if (conjuncts.size() != 1) {
        return false;
    }
    Expression conjunct = conjuncts.iterator().next();
    // leaf conjuncts have no first child to inspect, so they cannot match
    if (conjunct.children().isEmpty()) {
        return false;
    }
    Expression childExpr = conjunct.children().get(0);
    if (childExpr instanceof SlotReference) {
        Optional<Column> column = ((SlotReference) childExpr).getColumn();
        return column.map(Column::isDeleteSignColumn).orElse(false);
    }
    return false;
}