in flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecOverAggregate.java [222:432]
/**
 * Creates the {@link OverWindowFrame}s that evaluate every group of {@code overSpec}.
 *
 * <p>Groups whose inferred mode is {@link OverWindowMode#OFFSET} (LEAD/LAG) get one {@link
 * OffsetOverFrame} per aggregate call. Every other group gets a single frame that evaluates all
 * of its aggregate calls together. Frames are returned in the order in which the groups (and,
 * for OFFSET groups, their aggregate calls) appear in {@code overSpec}.
 *
 * @param typeFactory type factory used to build the aggregate info lists
 * @param relBuilder relational builder handed to the aggregate code generators and bound
 *     comparators
 * @param config node configuration used to create each code generator context
 * @param classLoader class loader used to create each code generator context
 * @param inputType row type of the over-aggregate input
 * @param sortSpec sort specification of the window; its field indices are passed to the
 *     aggregate info lists, and it is passed to the RANGE bound comparators
 * @param inputTypeWithConstants input row type that also considers the over spec's constants as
 *     input fields
 * @return the window frames, in group order
 * @throws TableException if a LEAD/LAG offset column has an unsupported type, a constant
 *     LEAD/LAG offset is NULL, or a group has an unexpected window shape
 */
private List<OverWindowFrame> createOverWindowFrames(
        FlinkTypeFactory typeFactory,
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        ClassLoader classLoader,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants) {
    final List<OverWindowFrame> windowFrames = new ArrayList<>();
    for (GroupSpec group : overSpec.getGroups()) {
        final OverWindowMode mode = inferGroupMode(group);
        if (mode == OverWindowMode.OFFSET) {
            // Each LEAD/LAG call is evaluated by its own frame, built with that call's offset.
            for (AggregateCall aggCall : group.getAggCalls()) {
                windowFrames.add(
                        createOffsetOverWindowFrame(
                                typeFactory,
                                relBuilder,
                                config,
                                classLoader,
                                inputType,
                                sortSpec,
                                inputTypeWithConstants,
                                aggCall));
            }
        } else {
            windowFrames.add(
                    createAggregateOverWindowFrame(
                            typeFactory,
                            relBuilder,
                            config,
                            classLoader,
                            inputType,
                            sortSpec,
                            inputTypeWithConstants,
                            group,
                            mode));
        }
    }
    return windowFrames;
}

/**
 * Creates the {@link OffsetOverFrame} that evaluates a single LEAD/LAG aggregate call.
 *
 * <p>The offset is the call's second argument and defaults to 1 when absent. An argument index
 * at or beyond {@code overSpec.getOriginalInputFields()} refers to one of the over spec's
 * constants, which is resolved once here. A smaller index refers to an input column, whose value
 * is read per row by a {@link OffsetOverFrame.CalcOffsetFunc}. LEAD offsets are positive and LAG
 * offsets are negative.
 *
 * @throws TableException if the offset column type is unsupported or the constant offset is
 *     NULL
 */
private OverWindowFrame createOffsetOverWindowFrame(
        FlinkTypeFactory typeFactory,
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        ClassLoader classLoader,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants,
        AggregateCall aggCall) {
    final AggregateInfoList aggInfoList =
            AggregateUtil.transformToBatchAggregateInfoList(
                    typeFactory,
                    inputTypeWithConstants,
                    JavaScalaConversionUtil.toScala(Collections.singletonList(aggCall)),
                    new boolean[] {true}, // needRetraction = true, see LeadLagAggFunction
                    sortSpec.getFieldIndices());
    final AggsHandlerCodeGenerator generator =
            new AggsHandlerCodeGenerator(
                    new CodeGeneratorContext(config, classLoader),
                    relBuilder,
                    JavaScalaConversionUtil.toScala(inputType.getChildren()),
                    false); // copyInputField
    // over agg code gen must pass the constants
    final GeneratedAggsHandleFunction genAggsHandler =
            generator
                    .needAccumulate()
                    .needRetract()
                    .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                    .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);

    // LEAD is behind the current row, so the offset is added (+1);
    // LAG is in front of the current row, so the offset is subtracted (-1).
    final long flag = aggCall.getAggregation().kind == SqlKind.LEAD ? 1L : -1L;

    // LEAD ( expression [, offset [, default] ] )
    // LAG ( expression [, offset [, default] ] )
    // The second argument is the offset argument index of the lead/lag function; default is 1.
    final Long offset;
    final OffsetOverFrame.CalcOffsetFunc calcOffsetFunc;
    if (aggCall.getArgList().size() < 2) {
        offset = flag;
        calcOffsetFunc = null;
    } else {
        final int offsetArgIndex = aggCall.getArgList().get(1);
        final int constantIndex = offsetArgIndex - overSpec.getOriginalInputFields();
        if (constantIndex < 0) {
            // The offset is an input column: it is computed for every row.
            offset = null;
            calcOffsetFunc = createOffsetCalcFunc(inputType, offsetArgIndex, flag);
        } else {
            // The offset is a constant: resolve it once. Guard against a NULL literal, which
            // would otherwise fail with an opaque NullPointerException while unboxing.
            final Long constantOffset =
                    getConstants().get(constantIndex).getValueAs(Long.class);
            if (constantOffset == null) {
                throw new TableException("The offset of LEAD/LAG must not be NULL.");
            }
            offset = constantOffset * flag;
            calcOffsetFunc = null;
        }
    }
    return new OffsetOverFrame(genAggsHandler, offset, calcOffsetFunc);
}

/**
 * Creates a function that reads the LEAD/LAG offset from field {@code offsetFieldIndex} of a row
 * and multiplies it by {@code flag}.
 *
 * <p>The function is static, so the returned lambdas capture only the field index and the flag.
 *
 * @throws TableException if the field is not BIGINT, INTEGER, SMALLINT or TINYINT
 */
private static OffsetOverFrame.CalcOffsetFunc createOffsetCalcFunc(
        RowType inputType, int offsetFieldIndex, long flag) {
    switch (inputType.getTypeAt(offsetFieldIndex).getTypeRoot()) {
        case BIGINT:
            return row -> row.getLong(offsetFieldIndex) * flag;
        case INTEGER:
            return row -> (long) row.getInt(offsetFieldIndex) * flag;
        case SMALLINT:
            return row -> (long) row.getShort(offsetFieldIndex) * flag;
        case TINYINT:
            return row -> (long) row.getByte(offsetFieldIndex) * flag;
        default:
            throw new TableException(
                    "The column type must be in long/int/short/tinyint, but was "
                            + inputType.getTypeAt(offsetFieldIndex)
                            + ".");
    }
}

/**
 * Creates the single frame that evaluates all aggregate calls of a non-OFFSET group.
 *
 * <p>The aggregate handler is generated once for the whole group. The window size is requested
 * only when one of the aggregate functions is a {@link SizeBasedWindowFunction}.
 *
 * @throws TableException if {@code mode} or the group's window shape is unexpected
 */
private OverWindowFrame createAggregateOverWindowFrame(
        FlinkTypeFactory typeFactory,
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        ClassLoader classLoader,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants,
        GroupSpec group,
        OverWindowMode mode) {
    final AggregateInfoList aggInfoList =
            AggregateUtil.transformToBatchAggregateInfoList(
                    typeFactory,
                    // use aggInputType which considers constants as input instead of
                    // inputSchema.relDataType
                    inputTypeWithConstants,
                    JavaScalaConversionUtil.toScala(group.getAggCalls()),
                    null, // aggCallNeedRetractions
                    sortSpec.getFieldIndices());
    final AggsHandlerCodeGenerator generator =
            new AggsHandlerCodeGenerator(
                    new CodeGeneratorContext(config, classLoader),
                    relBuilder,
                    JavaScalaConversionUtil.toScala(inputType.getChildren()),
                    false); // copyInputField
    if (Arrays.stream(aggInfoList.aggInfos())
            .anyMatch(info -> info.function() instanceof SizeBasedWindowFunction)) {
        generator.needWindowSize();
    }
    // over agg code gen must pass the constants
    final GeneratedAggsHandleFunction genAggsHandler =
            generator
                    .needAccumulate()
                    .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                    .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
    final RowType valueType = generator.valueType();
    switch (mode) {
        case RANGE:
            return createRangeOverWindowFrame(
                    relBuilder,
                    config,
                    classLoader,
                    inputType,
                    sortSpec,
                    group,
                    genAggsHandler,
                    valueType);
        case ROW:
            return createRowOverWindowFrame(inputType, group, genAggsHandler, valueType);
        case INSENSITIVE:
            return new InsensitiveOverFrame(genAggsHandler);
        default:
            throw new TableException("This should not happen.");
    }
}

/**
 * Creates the frame for a RANGE group, based on which of its bounds are unbounded.
 *
 * <p>Bounded sides are checked with generated bound comparators. The boolean passed to {@code
 * createBoundComparator} is {@code true} for the lower bound and {@code false} for the upper
 * bound.
 *
 * @throws TableException if the group matches none of the supported window shapes
 */
private OverWindowFrame createRangeOverWindowFrame(
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        ClassLoader classLoader,
        RowType inputType,
        SortSpec sortSpec,
        GroupSpec group,
        GeneratedAggsHandleFunction genAggsHandler,
        RowType valueType) {
    if (isUnboundedWindow(group)) {
        return new UnboundedOverWindowFrame(genAggsHandler, valueType);
    } else if (isUnboundedPrecedingWindow(group)) {
        final GeneratedRecordComparator genBoundComparator =
                createBoundComparator(
                        relBuilder,
                        config,
                        classLoader,
                        sortSpec,
                        group.getUpperBound(),
                        false,
                        inputType);
        return new RangeUnboundedPrecedingOverFrame(genAggsHandler, genBoundComparator);
    } else if (isUnboundedFollowingWindow(group)) {
        final GeneratedRecordComparator genBoundComparator =
                createBoundComparator(
                        relBuilder,
                        config,
                        classLoader,
                        sortSpec,
                        group.getLowerBound(),
                        true,
                        inputType);
        return new RangeUnboundedFollowingOverFrame(
                valueType, genAggsHandler, genBoundComparator);
    } else if (isSlidingWindow(group)) {
        // The left (lower) comparator is generated before the right (upper) one.
        final GeneratedRecordComparator genLeftBoundComparator =
                createBoundComparator(
                        relBuilder,
                        config,
                        classLoader,
                        sortSpec,
                        group.getLowerBound(),
                        true,
                        inputType);
        final GeneratedRecordComparator genRightBoundComparator =
                createBoundComparator(
                        relBuilder,
                        config,
                        classLoader,
                        sortSpec,
                        group.getUpperBound(),
                        false,
                        inputType);
        return new RangeSlidingOverFrame(
                inputType,
                valueType,
                genAggsHandler,
                genLeftBoundComparator,
                genRightBoundComparator);
    } else {
        throw new TableException("This should not happen.");
    }
}

/**
 * Creates the frame for a ROW group, based on which of its bounds are unbounded.
 *
 * <p>Bounded sides are resolved to row-count boundaries via {@code
 * OverAggregateUtil.getLongBoundary}.
 *
 * @throws TableException if the group matches none of the supported window shapes
 */
private OverWindowFrame createRowOverWindowFrame(
        RowType inputType,
        GroupSpec group,
        GeneratedAggsHandleFunction genAggsHandler,
        RowType valueType) {
    if (isUnboundedWindow(group)) {
        return new UnboundedOverWindowFrame(genAggsHandler, valueType);
    } else if (isUnboundedPrecedingWindow(group)) {
        return new RowUnboundedPrecedingOverFrame(
                genAggsHandler,
                OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
    } else if (isUnboundedFollowingWindow(group)) {
        return new RowUnboundedFollowingOverFrame(
                valueType,
                genAggsHandler,
                OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()));
    } else if (isSlidingWindow(group)) {
        return new RowSlidingOverFrame(
                inputType,
                valueType,
                genAggsHandler,
                OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()),
                OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
    } else {
        throw new TableException("This should not happen.");
    }
}