in flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java [2251:2584]
public void onMatch(RelOptRuleCall call) {
    // Rewrites a LEFT Correlate whose right side is a single-expression Project over an
    // Aggregate without group keys into Project -> Aggregate -> Project -> Join.
    //
    // The operands bound below are expected to form this tree (the pattern matching could be
    // simplified if rules could be applied during decorrelation):
    //
    //   Correlate (left correlation, condition = true)
    //     leftInput
    //     Project-A (a RexNode)
    //       Aggregate (groupby (0), agg0(), agg1()...)
    //         Project-B (references corVar)
    //           rightInput
    final Correlate correlate = call.rel(0);
    final RelNode left = call.rel(1);
    final Project aggOutputProject = call.rel(2);
    final Aggregate aggregate = call.rel(3);
    final Project aggInputProject = call.rel(4);
    RelNode right = call.rel(5);
    final RelBuilder builder = call.builder();
    final RexBuilder rexBuilder = builder.getRexBuilder();
    final RelOptCluster cluster = correlate.getCluster();

    d.setCurrent(call.getPlanner().getRoot(), correlate);

    // Project-A must produce exactly one expression.
    final List<RexNode> outputExprs = aggOutputProject.getProjects();
    if (outputExprs.size() != 1) {
        return;
    }

    // Correlate no longer carries a join condition, so the condition starts out as the literal
    // TRUE; the comparison against it is retained from the time when it was read from the rel.
    final JoinRelType joinType = correlate.getJoinType();
    RexNode joinCond = rexBuilder.makeLiteral(true);
    final boolean isUnconditionalLeftJoin =
            (joinType == JoinRelType.LEFT) && (joinCond == rexBuilder.makeLiteral(true));
    if (!isUnconditionalLeftJoin) {
        return;
    }

    // The aggregate must be computed over its entire input, i.e. have no group keys.
    if (!aggregate.getGroupSet().isEmpty()) {
        return;
    }

    final List<RexNode> inputExprs = aggInputProject.getProjects();
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();

    // Remember the ordinals of argument-less COUNT calls: after the rewrite they have to
    // count the nullIndicator column instead.
    final Set<Integer> countStarOrdinals = new HashSet<>();
    for (int ordinal = 0; ordinal < aggCalls.size(); ordinal++) {
        final AggregateCall candidate = aggCalls.get(ordinal);
        final boolean isCount = candidate.getAggregation() instanceof SqlCountAggFunction;
        if (isCount && candidate.getArgList().isEmpty()) {
            countStarOrdinals.add(ordinal);
        }
    }

    if ((right instanceof Filter) && d.cm.mapRefRelToCorRef.containsKey(right)) {
        // Case 1: the right input is a correlated filter:
        //
        //   Filter (references corVar)
        //     filterInput
        final Filter corFilter = (Filter) right;
        final RelNode filterInput = corFilter.getInput();
        assert filterInput instanceof HepRelVertex;
        right = ((HepRelVertex) filterInput).getCurrentRel();

        // The filter input itself must be free of correlation variables.
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
            return;
        }

        // Split the filter condition into AND'ed equality comparisons between keys of the
        // filter input (rightJoinKeys, arbitrary expressions) and correlated keys (expected
        // to be RexFieldAccess only).
        final List<RexNode> rightJoinKeys = new ArrayList<>();
        final List<RexNode> rawCorrelatedKeys = new ArrayList<>();
        RelOptUtil.splitCorrelatedFilterCondition(
                corFilter, rightJoinKeys, rawCorrelatedKeys, true);

        // Collect each correlated key twice: as the original field access and as the input
        // reference obtained by removing the correlation from it.
        final List<RexFieldAccess> corFieldAccesses = new ArrayList<>();
        final List<RexInputRef> leftKeyRefs = new ArrayList<>();
        for (RexNode rawKey : rawCorrelatedKeys) {
            assert rawKey instanceof RexFieldAccess;
            corFieldAccesses.add((RexFieldAccess) rawKey);
            final RexNode decorrelatedKey = d.removeCorrelationExpr(rawKey, false);
            assert decorrelatedKey instanceof RexInputRef;
            leftKeyRefs.add((RexInputRef) decorrelatedKey);
        }

        // Give up when no correlated key was extracted from the filter.
        if (leftKeyRefs.isEmpty()) {
            return;
        }

        // The correlated columns must form a unique key of the left input. The join filters
        // out nulls, so nulls in the join keys are acceptable.
        final RelMetadataQuery mq = call.getMetadataQuery();
        final boolean leftKeysUnique =
                RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, left, leftKeyRefs);
        if (!leftKeysUnique) {
            SQL2REL_LOGGER.debug("{} are not unique keys for {}", corFieldAccesses, left);
            return;
        }

        // The corVar references must be valid.
        if (!d.checkCorVars(correlate, aggInputProject, corFilter, corFieldAccesses)) {
            return;
        }

        // Rewrite this plan:
        //
        //   Correlate (left correlation, condition = true)
        //     leftInput
        //     Project-A (a RexNode)
        //       Aggregate (groupby (0), agg0(), agg1()...)
        //         Project-B (may reference corVar)
        //           Filter (references corVar)
        //             rightInput (no correlated reference)
        //
        // into:
        //
        //   Project-A' (all gby keys + rewritten nullable ProjExpr)
        //     Aggregate (groupby (all left input refs),
        //                count(nullIndicator), other aggs...)
        //       Project-B' (all left input refs plus
        //                   the rewritten original projected exprs)
        //         Join (corVar replaced by input ref from leftInput)
        //           leftInput
        //           Project (everything from rightInput plus
        //                    the nullIndicator "true")
        //             rightInput
        //
        // Any non-nullable field of the right side could serve as the indicator; a literal
        // "true" column is appended instead so no search for a non-null field is needed.
        //
        // The filter condition, with correlation removed, becomes the join condition.
        joinCond = d.removeCorrelationExpr(corFilter.getCondition(), false);
    } else if (d.cm.mapRefRelToCorRef.containsKey(aggInputProject)) {
        // Case 2: only Project-B references the correlation variable.

        // The right input must be free of correlation variables.
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
            return;
        }

        // The corVar references must be valid.
        if (!d.checkCorVars(correlate, aggInputProject, null, null)) {
            return;
        }

        // All left columns taken together must be unique, i.e. every left row is distinct and
        // grouping by all left fields is possible.
        final int leftColumnCount = left.getRowType().getFieldCount();
        final ImmutableBitSet allLeftColumns = ImmutableBitSet.range(leftColumnCount);
        final RelMetadataQuery mq = call.getMetadataQuery();
        if (!RelMdUtil.areColumnsDefinitelyUnique(mq, left, allLeftColumns)) {
            SQL2REL_LOGGER.debug("There are no unique keys for {}", left);
            return;
        }

        // Rewrite this plan:
        //
        //   Correlate (left correlation, condition = true)
        //     leftInput
        //     Project-A (a RexNode)
        //       Aggregate (groupby (0), agg0(), agg1()...)
        //         Project-B (references corVar)
        //           rightInput (no correlated reference)
        //
        // into:
        //
        //   Project-A' (all gby keys + rewritten nullable ProjExpr)
        //     Aggregate (groupby (all left input refs),
        //                count(nullIndicator), other aggs...)
        //       Project-B' (all left input refs plus
        //                   the rewritten original projected exprs)
        //         Join (LOJ cond = true)
        //           leftInput
        //           Project (everything from rightInput plus
        //                    the nullIndicator "true")
        //             rightInput
        //
        // As in the filter case, a literal "true" column is appended to the right side to act
        // as the null indicator.
    } else {
        return;
    }

    final RelDataType leftRowType = left.getRowType();
    final int leftFieldCount = leftRowType.getFieldCount();
    // Left fields + rewritten Project-B expressions + the trailing nullIndicator.
    final int joinOutputProjExprCount = leftFieldCount + inputExprs.size() + 1;

    // Append the constant "true" nullIndicator column to the right side and build the join.
    right =
            d.createProjectWithAdditionalExprs(
                    right,
                    ImmutableList.of(
                            Pair.of(rexBuilder.makeLiteral(true), "nullIndicator")));
    final Join join =
            (Join) d.relBuilder.push(left).push(right).join(joinType, joinCond).build();

    // In the join output the nullIndicator is the last field; reference it with a nullable
    // version of its type.
    final int joinNullIndicatorPos = join.getRowType().getFieldCount() - 1;
    final RelDataType nullableIndicatorType =
            cluster.getTypeFactory()
                    .createTypeWithNullability(
                            join.getRowType()
                                    .getFieldList()
                                    .get(joinNullIndicatorPos)
                                    .getType(),
                            true);
    final RexInputRef nullIndicator =
            new RexInputRef(joinNullIndicatorPos, nullableIndicatorType);

    // Project-B': all left fields (the future group keys; the left join keeps their types),
    // then the rewritten aggregate input expressions, then the nullIndicator.
    final List<RexNode> joinOutputExprs = new ArrayList<>();
    for (int i = 0; i < leftFieldCount; i++) {
        final RelDataType leftFieldType = leftRowType.getFieldList().get(i).getType();
        joinOutputExprs.add(rexBuilder.makeInputRef(leftFieldType, i));
    }
    for (RexNode inputExpr : inputExprs) {
        final RexNode rewritten =
                d.removeCorrelationExpr(
                        inputExpr, joinType.generatesNullsOnRight(), nullIndicator);
        joinOutputExprs.add(rewritten);
    }
    joinOutputExprs.add(rexBuilder.makeInputRef(join, joinNullIndicatorPos));

    final RelNode joinOutputProject = builder.push(join).project(joinOutputExprs).build();

    // Within Project-B' the nullIndicator sits at a different (but still last) position.
    final int projectedNullIndicatorPos = joinOutputProjExprCount - 1;

    // Re-target every aggregate call at Project-B': argument and filter ordinals shift by the
    // number of new group keys, argument-less COUNT calls count the nullIndicator instead.
    final int groupCount = leftFieldCount;
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    for (int ordinal = 0; ordinal < aggCalls.size(); ordinal++) {
        final AggregateCall aggCall = aggCalls.get(ordinal);

        final List<Integer> shiftedArgs;
        if (countStarOrdinals.contains(ordinal)) {
            shiftedArgs = Collections.singletonList(projectedNullIndicatorPos);
        } else {
            shiftedArgs = new ArrayList<>();
            for (int aggArg : aggCall.getArgList()) {
                shiftedArgs.add(aggArg + groupCount);
            }
        }

        // A negative filterArg is passed through unchanged.
        final int shiftedFilterArg;
        if (aggCall.filterArg < 0) {
            shiftedFilterArg = aggCall.filterArg;
        } else {
            shiftedFilterArg = aggCall.filterArg + groupCount;
        }

        newAggCalls.add(
                aggCall.adaptTo(
                        joinOutputProject,
                        shiftedArgs,
                        shiftedFilterArg,
                        aggregate.getGroupCount(),
                        groupCount));
    }

    // Aggregate grouped by all left fields.
    final ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount);
    builder.push(joinOutputProject);
    builder.aggregate(builder.groupKey(groupSet), newAggCalls);

    // Project-A': the group keys followed by the original output expression, decorrelated and
    // cast to a nullable version of its type.
    final List<RexNode> newOutputExprs = new ArrayList<>();
    for (int groupKey : groupSet) {
        newOutputExprs.add(rexBuilder.makeInputRef(builder.peek(), groupKey));
    }
    final RexNode decorrelatedOutput = d.removeCorrelationExpr(outputExprs.get(0), false);
    final RelDataType nullableOutputType =
            cluster.getTypeFactory()
                    .createTypeWithNullability(decorrelatedOutput.getType(), true);
    newOutputExprs.add(rexBuilder.makeCast(nullableOutputType, decorrelatedOutput));
    builder.project(newOutputExprs);

    call.transformTo(builder.build());
    d.removeCorVarFromTree(correlate);
}