in flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java [125:190]
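This method restores the operator's checkpointed state on recovery: the parallelism recorded at checkpoint time (used as a guard against rescaling), the epoch currently being replayed, and the data cache of records to replay.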
@Override
public void initializeState(StateInitializationContext context) throws Exception {
    super.initializeState(context);

    // The parallelism is stored as union list state so that on restore every subtask
    // sees the value recorded at checkpoint time. Fail fast if the parallelism has
    // changed, since this operator does not support rescaling.
    parallelismState =
            context.getOperatorStateStore()
                    .getUnionListState(
                            new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
    OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
            .ifPresent(
                    oldParallelism ->
                            checkState(
                                    oldParallelism
                                            == getRuntimeContext().getNumberOfParallelSubtasks(),
                                    "The Replay operator is recovered with parallelism changed from "
                                            + oldParallelism
                                            + " to "
                                            + getRuntimeContext()
                                                    .getNumberOfParallelSubtasks()));
    // Restore the epoch that was being replayed when the checkpoint was taken.
    currentEpochState =
            context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("epoch", IntSerializer.INSTANCE));
    OperatorStateUtils.getUniqueElement(currentEpochState, "epoch")
            .ifPresent(epoch -> currentEpoch = epoch);
    try {
        SupplierWithException<Path, IOException> pathGenerator =
                OperatorUtils.createDataCacheFileGenerator(
                        basePath, "replay", config.getOperatorID());

        // The cached records are checkpointed through the raw operator state stream.
        // Exactly one input stream is expected, since rescaling is not supported.
        DataCacheSnapshot dataCacheSnapshot = null;
        List<StatePartitionStreamProvider> rawStateInputs =
                IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
        if (rawStateInputs.size() > 0) {
            checkState(
                    rawStateInputs.size() == 1,
                    "Currently the replay operator does not support rescaling");
            dataCacheSnapshot =
                    DataCacheSnapshot.recover(
                            rawStateInputs.get(0).getStream(), fileSystem, pathGenerator);
        }

        // Rebuild the writer on top of the recovered segments, or start empty on a
        // fresh run.
        dataCacheWriter =
                new DataCacheWriter<>(
                        typeSerializer,
                        fileSystem,
                        pathGenerator,
                        dataCacheSnapshot == null
                                ? Collections.emptyList()
                                : dataCacheSnapshot.getSegments());

        // If a replay was in progress when the snapshot was taken, resume reading
        // from the recorded position instead of starting over.
        if (dataCacheSnapshot != null && dataCacheSnapshot.getReaderPosition() != null) {
            currentDataCacheReader =
                    new DataCacheReader<>(
                            typeSerializer,
                            dataCacheSnapshot.getSegments(),
                            dataCacheSnapshot.getReaderPosition());
        }
    } catch (Exception e) {
        throw new FlinkRuntimeException("Failed to replay the records", e);
    }
}
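
For orientation, the recovery logic above pairs with a snapshot path that records the same state. Below is a minimal sketch of that counterpart, assuming the `parallelismState`, `currentEpochState`, and `currentEpoch` fields seen above; it is an illustration, not the actual `ReplayOperator.snapshotState`, and it omits writing the data cache segments and reader position to the raw operator state stream (the counterpart of the `DataCacheSnapshot.recover` call above).

@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
    super.snapshotState(context);

    // Record the current parallelism so that initializeState() can detect a
    // parallelism change on recovery and fail fast.
    parallelismState.clear();
    parallelismState.add(getRuntimeContext().getNumberOfParallelSubtasks());

    // Record the epoch currently being replayed.
    currentEpochState.clear();
    currentEpochState.add(currentEpoch);

    // Omitted: persisting the cached segments and the current reader position
    // through the raw operator state stream, which initializeState() recovers
    // via DataCacheSnapshot.recover(...).
}

Storing the parallelism as union list state means every restored subtask observes the value recorded at checkpoint time, which is what makes the rescaling guard in `initializeState` possible.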