in flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java [285:352]
/**
 * Restores the operator state of this per-round wrapper operator.
 *
 * <p>The restore proceeds in the following order:
 *
 * <ol>
 *   <li>If a previous parallelism was recorded in the union "parallelism" state, verifies that it
 *       equals the current number of parallel subtasks.
 *   <li>Restores {@code latestEpochWatermark} from the "latestEpoch" list state, if present.
 *   <li>Reads the "rawStateEpoch" and "pendingEpochs" list states.
 *   <li>For every pending epoch, calls {@code getWrappedOperator} with the raw operator state
 *       streams that belong to that epoch.
 * </ol>
 *
 * @param context the state initialization context supplied by the runtime
 * @throws Exception if the state cannot be read, or (via {@code checkState}) if the parallelism
 *     changed or the stored raw-state epochs are inconsistent with the pending epochs
 */
public void initializeState(StateInitializationContext context) throws Exception {
    // Union list state is redistributed in full to every subtask, so each subtask can see the
    // parallelism that was recorded when the snapshot was taken.
    parallelismState =
            context.getOperatorStateStore()
                    .getUnionListState(
                            new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
    OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
            .ifPresent(
                    oldParallelism -> {
                        int currentParallelism =
                                containingTask
                                        .getEnvironment()
                                        .getTaskInfo()
                                        .getNumberOfParallelSubtasks();
                        checkState(
                                oldParallelism == currentParallelism,
                                "The per-round wrapper operator is recovered with parallelism "
                                        + "changed from %s to %s",
                                oldParallelism,
                                currentParallelism);
                    });

    latestEpochWatermarkState =
            context.getOperatorStateStore()
                    .getListState(
                            new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
    OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
            .ifPresent(
                    oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);

    // Note that the list must be sorted.
    rawStateEpochState =
            context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("rawStateEpoch", Integer.class));
    List<Integer> rawStateEpochs = IteratorUtils.toList(rawStateEpochState.get().iterator());

    // Note that the list must be sorted.
    pendingEpochState =
            context.getOperatorStateStore()
                    .getListState(
                            new ListStateDescriptor<>("pendingEpochs", IntSerializer.INSTANCE));
    List<Integer> pendingEpochs = IteratorUtils.toList(pendingEpochState.get().iterator());

    // The raw state input stream for an entry cannot be obtained until the previous entries
    // have been consumed. We therefore walk the two sorted epoch lists in a single "merge" pass.
    Iterator<StatePartitionStreamProvider> rawStates =
            context.getRawOperatorStateInputs().iterator();
    int nextRawStateEntryIndex = 0;
    for (int epoch : pendingEpochs) {
        // Every raw state entry not yet consumed must belong to this epoch or a later one.
        // The message arguments are only formatted if the check fails.
        checkState(
                nextRawStateEntryIndex == rawStateEpochs.size()
                        || rawStateEpochs.get(nextRawStateEntryIndex) >= epoch,
                "Unexpected raw state indices %s and epochs %s",
                rawStateEpochs,
                pendingEpochs);

        // Count how many consecutive raw state entries belong to this epoch.
        int numberOfStateEntries = 0;
        while (nextRawStateEntryIndex < rawStateEpochs.size()
                && rawStateEpochs.get(nextRawStateEntryIndex) == epoch) {
            numberOfStateEntries++;
            nextRawStateEntryIndex++;
        }

        // NOTE(review): this presumably creates/opens the wrapped operator for this epoch and
        // restores it from the next numberOfStateEntries streams of rawStates. Confirm in
        // getWrappedOperator.
        getWrappedOperator(epoch, rawStates, numberOfStateEntries);
    }
    // NOTE(review): raw state entries whose epoch is greater than the last pending epoch are not
    // validated by the loop above.
}