flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java [90:613]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>>
        extends AbstractWrapperOperator<T>
        implements StreamOperatorStateHandler.CheckpointedStreamOperator {

    private static final Logger LOG =
            LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);

    /** Class name used to recognize the heap keyed state backend for reflective cleanup. */
    private static final String HEAP_KEYED_STATE_NAME =
            "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";

    /** Class name used to recognize the RocksDB keyed state backend for reflective cleanup. */
    private static final String ROCKSDB_KEYED_STATE_NAME =
            "org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";

    /** The wrapped operators for each round. */
    private final Map<Integer, S> wrappedOperators;

    /** Latency statistics fed by the latency markers that pass through this wrapper. */
    protected final LatencyStats latencyStats;

    /** State context created in {@link #initializeState(StreamTaskStateInitializer)}. */
    private transient StreamOperatorStateContext streamOperatorStateContext;

    /** Handles operator/keyed state access and snapshotting for this wrapper. */
    private transient StreamOperatorStateHandler stateHandler;

    /** Timer service manager obtained from the state context; included in snapshots. */
    private transient InternalTimeServiceManager<?> timeServiceManager;

    /** Key selector of the first input; null when that input is not keyed. */
    private transient KeySelector<?, ?> stateKeySelector1;

    /** Key selector of the second input; null when that input is not keyed. */
    private transient KeySelector<?, ?> stateKeySelector2;

    /** The largest epoch watermark received so far; -1 before any watermark is received. */
    private int latestEpochWatermark = -1;

    // --------------- state ---------------------------

    /** Stores the parallelism of the last run so that we could check if parallelism is changed. */
    private ListState<Integer> parallelismState;

    /** Stores the latest epoch watermark received. */
    private ListState<Integer> latestEpochWatermarkState;

    /**
     * Stores the list of epochs that are not aligned yet. After failover the wrapped operators
     * corresponding to these epochs would be re-created.
     */
    private ListState<Integer> pendingEpochState;

    /** Stores, sorted, the epoch that each raw operator state partition belongs to. */
    private ListState<Integer> rawStateEpochState;

    /**
     * Creates the per-round wrapper operator.
     *
     * @param parameters the stream operator parameters of this wrapper.
     * @param operatorFactory the factory used to create one wrapped operator per round.
     */
    public AbstractPerRoundWrapperOperator(
            StreamOperatorParameters<IterationRecord<T>> parameters,
            StreamOperatorFactory<T> operatorFactory) {
        super(parameters, operatorFactory);

        this.wrappedOperators = new HashMap<>();
        this.latencyStats = initializeLatencyStats();
    }

    /** Returns the operator for the given round, creating it without any raw state to restore. */
    protected S getWrappedOperator(int round) {
        Iterator<StatePartitionStreamProvider> noRawStates =
                CloseableIterable.<StatePartitionStreamProvider>empty().iterator();
        return getWrappedOperator(round, noRawStates, 0);
    }

    /**
     * Returns the wrapped operator for the given round, creating, initializing and opening it on
     * first access.
     *
     * @param round the epoch/round the operator belongs to.
     * @param rawOperatorStates iterator over the raw operator state streams to restore from.
     * @param count the number of raw state entries that belong to this round.
     * @return the cached or newly created operator for this round.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S getWrappedOperator(
            int round, Iterator<StatePartitionStreamProvider> rawOperatorStates, int count) {
        S wrappedOperator = wrappedOperators.get(round);
        if (wrappedOperator != null) {
            return wrappedOperator;
        }

        // We need to clone the operator factory to also support SimpleOperatorFactory.
        try {
            StreamOperatorFactory<T> clonedOperatorFactory =
                    InstantiationUtil.clone(
                            operatorFactory, containingTask.getUserCodeClassLoader());
            wrappedOperator =
                    (S)
                            StreamOperatorFactoryUtil.<T, S>createOperator(
                                            clonedOperatorFactory,
                                            (StreamTask) parameters.getContainingTask(),
                                            OperatorUtils.createWrappedOperatorConfig(
                                                    parameters.getStreamConfig(),
                                                    containingTask.getUserCodeClassLoader()),
                                            proxyOutput,
                                            parameters.getOperatorEventDispatcher())
                                    .f0;
            initializeStreamOperator(wrappedOperator, round, rawOperatorStates, count);
            wrappedOperators.put(round, wrappedOperator);
            return wrappedOperator;
        } catch (Exception e) {
            ExceptionUtils.rethrow(e);
        }

        // Unreachable: rethrow always throws, but the compiler still requires a return here.
        return wrappedOperator;
    }

    /**
     * Ends the input of the given operator for the given epoch and emits the final watermark;
     * called while the operator of that epoch is being closed.
     */
    protected abstract void endInputAndEmitMaxWatermark(S operator, int epoch, int epochWatermark)
            throws Exception;

    /**
     * Finishes and closes the operator of the given epoch, then cleans up the per-round states it
     * registered.
     *
     * @param operator the operator to close.
     * @param epoch the epoch/round the operator belongs to.
     * @param epochWatermark the epoch watermark that triggered the close.
     * @throws Exception if closing the operator or cleaning up its states fails.
     */
    protected void closeStreamOperator(S operator, int epoch, int epochWatermark) throws Exception {
        setIterationContextRound(epoch);
        // Notify IterationListener implementations (operator or UDF) before ending the input.
        OperatorUtils.processOperatorOrUdfIfSatisfy(
                operator,
                IterationListener.class,
                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
        endInputAndEmitMaxWatermark(operator, epoch, epochWatermark);
        operator.finish();
        operator.close();
        setIterationContextRound(null);

        // Cleanup the states used by this operator.
        cleanupOperatorStates(epoch);

        if (stateHandler.getKeyedStateBackend() != null) {
            cleanupKeyedStates(epoch);
        }
    }

    /**
     * Closes the wrapped operator(s) whose round is finished by this epoch watermark.
     *
     * <p>For a regular watermark only the operator of exactly that epoch is closed; for the
     * terminating watermark ({@code Integer.MAX_VALUE}) all remaining operators are closed in
     * ascending epoch order.
     *
     * @param epochWatermark the received epoch watermark; must be non-negative.
     */
    @Override
    public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
        checkState(epochWatermark >= 0, "The epoch watermark should be non-negative.");
        // Only act on a strict increase over the last seen watermark.
        if (epochWatermark > latestEpochWatermark) {
            latestEpochWatermark = epochWatermark;

            // Destroys all the operators with round < epoch watermark. Notes that
            // the onEpochWatermarkIncrement must be from 0 and increment by 1 each time, except
            // for the last round.
            try {
                if (epochWatermark < Integer.MAX_VALUE) {
                    S wrappedOperator = wrappedOperators.remove(epochWatermark);
                    if (wrappedOperator != null) {
                        closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
                    }
                } else {
                    List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
                    Collections.sort(sortedEpochs);
                    for (Integer epoch : sortedEpochs) {
                        closeStreamOperator(wrappedOperators.remove(epoch), epoch, epochWatermark);
                    }
                }
            } catch (Exception exception) {
                ExceptionUtils.rethrow(exception);
            }
        }

        super.onEpochWatermarkIncrement(epochWatermark);
    }

    /** Applies the given consumer to every (epoch, operator) pair that is currently alive. */
    protected void processForEachWrappedOperator(
            BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
        for (Integer epoch : wrappedOperators.keySet()) {
            consumer.accept(epoch, wrappedOperators.get(epoch));
        }
    }

    /**
     * No-op: the wrapper has nothing to open itself; each wrapped operator is opened lazily when
     * its round is first accessed (see {@code initializeStreamOperator}).
     */
    @Override
    public void open() throws Exception {}

    /**
     * Creates the state context, state handler, timer service manager and key selectors for this
     * wrapper; the wrapped operators later share this context via proxies.
     *
     * @param streamTaskStateManager the initializer provided by the containing stream task.
     */
    @Override
    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
            throws Exception {
        final TypeSerializer<?> keySerializer =
                streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());

        streamOperatorStateContext =
                streamTaskStateManager.streamOperatorStateContext(
                        getOperatorID(),
                        getClass().getSimpleName(),
                        parameters.getProcessingTimeService(),
                        this,
                        keySerializer,
                        containingTask.getCancelables(),
                        metrics,
                        streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
                                ManagedMemoryUseCase.STATE_BACKEND,
                                containingTask
                                        .getEnvironment()
                                        .getTaskManagerInfo()
                                        .getConfiguration(),
                                containingTask.getUserCodeClassLoader()),
                        isUsingCustomRawKeyedState());

        // The handler restores the operator state and triggers initializeState(context) below.
        stateHandler =
                new StreamOperatorStateHandler(
                        streamOperatorStateContext,
                        containingTask.getExecutionConfig(),
                        containingTask.getCancelables());
        stateHandler.initializeOperatorState(this);
        this.timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();

        // Key selectors for the two possible inputs; null when an input is not keyed.
        stateKeySelector1 =
                streamConfig.getStatePartitioner(0, containingTask.getUserCodeClassLoader());
        stateKeySelector2 =
                streamConfig.getStatePartitioner(1, containingTask.getUserCodeClassLoader());
    }

    /**
     * Restores the wrapper's own states and re-creates the wrapped operators for every epoch that
     * was still pending at checkpoint time, feeding each one its raw operator state entries.
     *
     * @param context the state initialization context.
     * @throws Exception if restoring the states or re-creating an operator fails.
     */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        parallelismState =
                context.getOperatorStateStore()
                        .getUnionListState(
                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
        // The parallelism must stay unchanged across restarts.
        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
                .ifPresent(
                        oldParallelism ->
                                checkState(
                                        oldParallelism
                                                == containingTask
                                                        .getEnvironment()
                                                        .getTaskInfo()
                                                        .getNumberOfParallelSubtasks(),
                                        // This is the per-round wrapper; the previous message
                                        // wrongly referred to the all-round wrapper.
                                        "The per-round wrapper operator is recovered with parallelism changed from "
                                                + oldParallelism
                                                + " to "
                                                + containingTask
                                                        .getEnvironment()
                                                        .getTaskInfo()
                                                        .getNumberOfParallelSubtasks()));

        latestEpochWatermarkState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
                .ifPresent(
                        oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);

        // Notes that the list must be sorted.
        rawStateEpochState =
                context.getOperatorStateStore()
                        .getListState(new ListStateDescriptor<>("rawStateEpoch", Integer.class));
        List<Integer> rawStateEpochs = IteratorUtils.toList(rawStateEpochState.get().iterator());

        // Notes that the list must be sorted.
        pendingEpochState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("pendingEpochs", IntSerializer.INSTANCE));
        List<Integer> pendingEpochs = IteratorUtils.toList(pendingEpochState.get().iterator());

        // Unfortunately, for the raw state we could not call get input stream unless the previous
        // records are consumed. We would have to do a "merge" of the two lists.
        Iterator<StatePartitionStreamProvider> rawStates =
                context.getRawOperatorStateInputs().iterator();

        int nextRawStateEntryIndex = 0;
        for (int epoch : pendingEpochs) {
            // Raw state entries for epochs without a pending operator would violate the sort
            // invariant maintained in snapshotState.
            checkState(
                    nextRawStateEntryIndex == rawStateEpochs.size()
                            || rawStateEpochs.get(nextRawStateEntryIndex) >= epoch,
                    String.format(
                            "Unexpected raw state indices %s and epochs %s",
                            rawStateEpochs.toString(), pendingEpochs.toString()));
            // Let's find how many entries this epoch has.
            int numberOfStateEntries = 0;
            while (nextRawStateEntryIndex < rawStateEpochs.size()
                    && rawStateEpochs.get(nextRawStateEntryIndex) == epoch) {
                numberOfStateEntries++;
                nextRawStateEntryIndex++;
            }

            // Re-create the operator for this pending epoch and restore its raw state entries.
            getWrappedOperator(epoch, rawStates, numberOfStateEntries);
        }
    }

    /**
     * Whether this operator writes custom raw keyed state; false by default, subclasses may
     * override. The flag is forwarded to state context creation and snapshotting.
     */
    @Internal
    protected boolean isUsingCustomRawKeyedState() {
        return false;
    }

    /**
     * Verifies that every wrapped operator has already been closed (which happens when its epoch
     * watermark arrives) before the wrapper itself finishes.
     *
     * @throws IllegalStateException if some per-round operators are still open.
     */
    @Override
    public void finish() throws Exception {
        checkState(
                wrappedOperators.isEmpty(),
                "Some wrapped operators are still not closed yet: " + wrappedOperators.keySet());
    }

    /**
     * Disposes the state handler if it was created; it may be null when initializeState was never
     * reached before closing.
     */
    @Override
    public void close() throws Exception {
        if (stateHandler != null) {
            stateHandler.dispose();
        }
    }

    /** Forwards the pre-barrier hook to every live per-round operator. */
    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
        }
    }

    /**
     * Delegates the snapshot to the state handler, passing this wrapper as the checkpointed
     * operator so the custom logic in {@link #snapshotState(StateSnapshotContext)} runs.
     */
    @Override
    public OperatorSnapshotFutures snapshotState(
            long checkpointId,
            long timestamp,
            CheckpointOptions checkpointOptions,
            CheckpointStreamFactory factory)
            throws Exception {
        return stateHandler.snapshotState(
                this,
                Optional.ofNullable(timeServiceManager),
                streamConfig.getOperatorName(),
                checkpointId,
                timestamp,
                checkpointOptions,
                factory,
                isUsingCustomRawKeyedState());
    }

    /**
     * Snapshots the raw state of all live wrapped operators in ascending epoch order, recording
     * which epoch each raw state partition belongs to, then snapshots the wrapper's own states.
     */
    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        OperatorStateCheckpointOutputStream rawOperatorStateOutputStream =
                context.getRawOperatorStateOutput();
        // One entry per raw state partition, holding the epoch that wrote it.
        List<Integer> operatorStateEpoch = new ArrayList<>();

        List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
        Collections.sort(sortedEpochs);

        for (int epoch : sortedEpochs) {
            S wrappedOperator = wrappedOperators.get(epoch);
            if (StreamOperatorStateHandler.CheckpointedStreamOperator.class.isAssignableFrom(
                    wrappedOperator.getClass())) {
                ((StreamOperatorStateHandler.CheckpointedStreamOperator) wrappedOperator)
                        .snapshotState(new ProxyStateSnapshotContext(context));

                // Gets the count of the raw operator state.
                // Attribute every newly written partition to the current epoch.
                int numberOfPartitions = rawOperatorStateOutputStream.getNumberOfPartitions();
                while (operatorStateEpoch.size() < numberOfPartitions) {
                    operatorStateEpoch.add(epoch);
                }
            }
        }

        // Then snapshot our own states
        // Always clear the union list state before set value.
        parallelismState.clear();
        // Only subtask 0 records the parallelism to avoid duplicates in the union state.
        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
            parallelismState.update(
                    Collections.singletonList(
                            containingTask
                                    .getEnvironment()
                                    .getTaskInfo()
                                    .getNumberOfParallelSubtasks()));
        }
        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));

        // The list must be sorted
        rawStateEpochState.update(operatorStateEpoch);

        // The list must be sorted
        pendingEpochState.update(sortedEpochs);
    }

    /** Sets the current key from the first input's record, using the first-input key selector. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement1(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector1);
    }

    /** Sets the current key from the second input's record, using the second-input key selector. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement2(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector2);
    }

    /**
     * Sets the current key for a record if the input is keyed; non-RECORD iteration events (e.g.
     * epoch watermarks) carry no user value and thus no key.
     */
    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
            throws Exception {
        if (selector == null) {
            return;
        }
        IterationRecord<?> iterationRecord = (IterationRecord<?>) record.getValue();
        if (iterationRecord.getType() == IterationRecord.Type.RECORD) {
            setCurrentKey(selector.getKey(record.getValue()));
        }
    }

    /** Returns the metric group of this wrapper operator. */
    @Override
    public OperatorMetricGroup getMetricGroup() {
        return metrics;
    }

    /** Returns the operator id configured in the stream config. */
    @Override
    public OperatorID getOperatorID() {
        return streamConfig.getOperatorID();
    }

    /** Fans the checkpoint-complete notification out to every live per-round operator. */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointComplete(checkpointId);
        }
    }

    /** Fans the checkpoint-aborted notification out to every live per-round operator. */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointAborted(checkpointId);
        }
    }

    /** Forwards the current key to the shared state handler. */
    @Override
    public void setCurrentKey(Object key) {
        stateHandler.setCurrentKey(key);
    }

    /**
     * Returns the current key from the state handler, or null if the handler has not been created
     * yet (i.e. before initializeState).
     */
    @Override
    public Object getCurrentKey() {
        return stateHandler == null ? null : stateHandler.getCurrentKey();
    }

    /** Records the latency marker in the stats and forwards it downstream. */
    protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
        // all operators are tracking latencies
        this.latencyStats.reportLatency(marker);

        // everything except sinks forwards latency markers
        this.output.emitLatencyMarker(marker);
    }

    /**
     * Creates the {@link LatencyStats} from the task manager configuration (history size and
     * source granularity), falling back to an unregistered single-granularity instance if reading
     * the configuration or metric groups fails.
     */
    private LatencyStats initializeLatencyStats() {
        try {
            Configuration taskManagerConfig =
                    containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();
            int historySize = taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
            if (historySize <= 0) {
                LOG.warn(
                        "{} has been set to a value equal or below 0: {}. Using default.",
                        MetricOptions.LATENCY_HISTORY_SIZE,
                        historySize);
                historySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
            }

            final String configuredGranularity =
                    taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
            LatencyStats.Granularity granularity;
            try {
                granularity =
                        LatencyStats.Granularity.valueOf(
                                configuredGranularity.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException iae) {
                // Unknown granularity names degrade to OPERATOR rather than failing the task.
                granularity = LatencyStats.Granularity.OPERATOR;
                LOG.warn(
                        "Configured value {} option for {} is invalid. Defaulting to {}.",
                        configuredGranularity,
                        MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
                        granularity);
            }
            MetricGroup jobMetricGroup = this.metrics.getJobMetricGroup();
            return new LatencyStats(
                    jobMetricGroup.addGroup("latency"),
                    historySize,
                    containingTask.getIndexInSubtaskGroup(),
                    getOperatorID(),
                    granularity);
        } catch (Exception e) {
            // Metrics are best-effort: never fail operator construction because of them.
            LOG.warn("An error occurred while instantiating latency metrics.", e);
            return new LatencyStats(
                    UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
                            .addGroup("latency"),
                    1,
                    0,
                    new OperatorID(),
                    LatencyStats.Granularity.SINGLE);
        }
    }

    /**
     * Initializes and opens a newly created per-round operator. Instead of a real state
     * initializer, the operator receives a {@link ProxyStreamOperatorStateContext} that prefixes
     * its state names with the round prefix and supplies its raw operator state entries.
     *
     * @param operator the freshly created operator.
     * @param round the epoch/round the operator belongs to.
     * @param rawOperatorStates iterator over the raw operator state streams to restore from.
     * @param count the number of raw state entries belonging to this round.
     */
    private void initializeStreamOperator(
            S operator,
            int round,
            Iterator<StatePartitionStreamProvider> rawOperatorStates,
            int count)
            throws Exception {
        operator.initializeState(
                (operatorID,
                        operatorClassName,
                        processingTimeService,
                        keyContext,
                        keySerializer,
                        streamTaskCloseableRegistry,
                        metricGroup,
                        managedMemoryFraction,
                        isUsingCustomRawKeyedState) ->
                        new ProxyStreamOperatorStateContext(
                                streamOperatorStateContext,
                                getRoundStatePrefix(round),
                                rawOperatorStates,
                                count));
        operator.open();
    }

    /**
     * Removes every operator state whose name carries the given round's prefix. Since the backend
     * offers no public removal API, this reaches into {@link DefaultOperatorStateBackend}'s
     * internal maps via reflection; other backend implementations are only warned about.
     */
    private void cleanupOperatorStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend();

        if (operatorStateBackend instanceof DefaultOperatorStateBackend) {
            // Both the registered state maps and the access caches must be purged.
            for (String fieldNames :
                    new String[] {
                        "registeredOperatorStates",
                        "registeredBroadcastStates",
                        "accessedStatesByName",
                        "accessedBroadcastStatesByName"
                    }) {
                Map<String, ?> field =
                        ReflectionUtils.getFieldValue(
                                operatorStateBackend,
                                DefaultOperatorStateBackend.class,
                                fieldNames);
                field.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix));
            }
        } else {
            LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend);
        }
    }

    private void cleanupKeyedStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend();
        if (keyedStateBackend.getClass().getName().equals(HEAP_KEYED_STATE_NAME)) {
            ReflectionUtils.<Map<String, ?>>getFieldValue(
                            keyedStateBackend, HeapKeyedStateBackend.class, "registeredKVStates")
                    .entrySet()
                    .removeIf(entry -> entry.getKey().startsWith(roundPrefix));
            ReflectionUtils.<Map<String, ?>>getFieldValue(
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java [90:613]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>>
        extends AbstractWrapperOperator<T>
        implements StreamOperatorStateHandler.CheckpointedStreamOperator {

    private static final Logger LOG =
            LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);

    /** Class name used to recognize the heap keyed state backend for reflective cleanup. */
    private static final String HEAP_KEYED_STATE_NAME =
            "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";

    /** Class name used to recognize the RocksDB keyed state backend for reflective cleanup. */
    private static final String ROCKSDB_KEYED_STATE_NAME =
            "org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";

    /** The wrapped operators for each round. */
    private final Map<Integer, S> wrappedOperators;

    /** Latency statistics fed by the latency markers that pass through this wrapper. */
    protected final LatencyStats latencyStats;

    /** State context created in {@link #initializeState(StreamTaskStateInitializer)}. */
    private transient StreamOperatorStateContext streamOperatorStateContext;

    /** Handles operator/keyed state access and snapshotting for this wrapper. */
    private transient StreamOperatorStateHandler stateHandler;

    /** Timer service manager obtained from the state context; included in snapshots. */
    private transient InternalTimeServiceManager<?> timeServiceManager;

    /** Key selector of the first input; null when that input is not keyed. */
    private transient KeySelector<?, ?> stateKeySelector1;

    /** Key selector of the second input; null when that input is not keyed. */
    private transient KeySelector<?, ?> stateKeySelector2;

    /** The largest epoch watermark received so far; -1 before any watermark is received. */
    private int latestEpochWatermark = -1;

    // --------------- state ---------------------------

    /** Stores the parallelism of the last run so that we could check if parallelism is changed. */
    private ListState<Integer> parallelismState;

    /** Stores the latest epoch watermark received. */
    private ListState<Integer> latestEpochWatermarkState;

    /**
     * Stores the list of epochs that are not aligned yet. After failover the wrapped operators
     * corresponding to these epochs would be re-created.
     */
    private ListState<Integer> pendingEpochState;

    /** Stores, sorted, the epoch that each raw operator state partition belongs to. */
    private ListState<Integer> rawStateEpochState;

    /**
     * Creates the per-round wrapper operator.
     *
     * @param parameters the stream operator parameters of this wrapper.
     * @param operatorFactory the factory used to create one wrapped operator per round.
     */
    public AbstractPerRoundWrapperOperator(
            StreamOperatorParameters<IterationRecord<T>> parameters,
            StreamOperatorFactory<T> operatorFactory) {
        super(parameters, operatorFactory);

        this.wrappedOperators = new HashMap<>();
        this.latencyStats = initializeLatencyStats();
    }

    /** Returns the operator for the given round, creating it without any raw state to restore. */
    protected S getWrappedOperator(int round) {
        Iterator<StatePartitionStreamProvider> noRawStates =
                CloseableIterable.<StatePartitionStreamProvider>empty().iterator();
        return getWrappedOperator(round, noRawStates, 0);
    }

    /**
     * Returns the wrapped operator for the given round, creating, initializing and opening it on
     * first access.
     *
     * @param round the epoch/round the operator belongs to.
     * @param rawOperatorStates iterator over the raw operator state streams to restore from.
     * @param count the number of raw state entries that belong to this round.
     * @return the cached or newly created operator for this round.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S getWrappedOperator(
            int round, Iterator<StatePartitionStreamProvider> rawOperatorStates, int count) {
        S wrappedOperator = wrappedOperators.get(round);
        if (wrappedOperator != null) {
            return wrappedOperator;
        }

        // We need to clone the operator factory to also support SimpleOperatorFactory.
        try {
            StreamOperatorFactory<T> clonedOperatorFactory =
                    InstantiationUtil.clone(
                            operatorFactory, containingTask.getUserCodeClassLoader());
            wrappedOperator =
                    (S)
                            StreamOperatorFactoryUtil.<T, S>createOperator(
                                            clonedOperatorFactory,
                                            (StreamTask) parameters.getContainingTask(),
                                            OperatorUtils.createWrappedOperatorConfig(
                                                    parameters.getStreamConfig(),
                                                    containingTask.getUserCodeClassLoader()),
                                            proxyOutput,
                                            parameters.getOperatorEventDispatcher())
                                    .f0;
            initializeStreamOperator(wrappedOperator, round, rawOperatorStates, count);
            wrappedOperators.put(round, wrappedOperator);
            return wrappedOperator;
        } catch (Exception e) {
            ExceptionUtils.rethrow(e);
        }

        // Unreachable: rethrow always throws, but the compiler still requires a return here.
        return wrappedOperator;
    }

    /**
     * Ends the input of the given operator for the given epoch and emits the final watermark;
     * called while the operator of that epoch is being closed.
     */
    protected abstract void endInputAndEmitMaxWatermark(S operator, int epoch, int epochWatermark)
            throws Exception;

    /**
     * Finishes and closes the operator of the given epoch, then cleans up the per-round states it
     * registered.
     *
     * @param operator the operator to close.
     * @param epoch the epoch/round the operator belongs to.
     * @param epochWatermark the epoch watermark that triggered the close.
     * @throws Exception if closing the operator or cleaning up its states fails.
     */
    protected void closeStreamOperator(S operator, int epoch, int epochWatermark) throws Exception {
        setIterationContextRound(epoch);
        // Notify IterationListener implementations (operator or UDF) before ending the input.
        OperatorUtils.processOperatorOrUdfIfSatisfy(
                operator,
                IterationListener.class,
                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
        endInputAndEmitMaxWatermark(operator, epoch, epochWatermark);
        operator.finish();
        operator.close();
        setIterationContextRound(null);

        // Cleanup the states used by this operator.
        cleanupOperatorStates(epoch);

        if (stateHandler.getKeyedStateBackend() != null) {
            cleanupKeyedStates(epoch);
        }
    }

    /**
     * Reacts to a new epoch watermark by closing the wrapped operator(s) whose rounds are
     * complete.
     *
     * <p>For a regular watermark only the operator of exactly that epoch is closed; for the
     * terminating watermark ({@code Integer.MAX_VALUE}) all remaining operators are closed in
     * ascending epoch order.
     *
     * @param epochWatermark the received epoch watermark; must be non-negative.
     */
    @Override
    public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
        checkState(epochWatermark >= 0, "The epoch watermark should be non-negative.");
        // Only act on a strict increase over the last seen watermark.
        if (epochWatermark > latestEpochWatermark) {
            latestEpochWatermark = epochWatermark;

            // Destroys all the operators with round < epoch watermark. Notes that
            // the onEpochWatermarkIncrement must be from 0 and increment by 1 each time, except
            // for the last round.
            try {
                if (epochWatermark < Integer.MAX_VALUE) {
                    S wrappedOperator = wrappedOperators.remove(epochWatermark);
                    if (wrappedOperator != null) {
                        closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
                    }
                } else {
                    List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
                    Collections.sort(sortedEpochs);
                    for (Integer epoch : sortedEpochs) {
                        closeStreamOperator(wrappedOperators.remove(epoch), epoch, epochWatermark);
                    }
                }
            } catch (Exception exception) {
                ExceptionUtils.rethrow(exception);
            }
        }

        super.onEpochWatermarkIncrement(epochWatermark);
    }

    /** Applies the given consumer to every (epoch, operator) pair that is currently alive. */
    protected void processForEachWrappedOperator(
            BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
        for (Integer epoch : wrappedOperators.keySet()) {
            consumer.accept(epoch, wrappedOperators.get(epoch));
        }
    }

    /**
     * No-op: the wrapper has nothing to open itself; each wrapped operator is opened lazily when
     * its round is first accessed (see {@code initializeStreamOperator}).
     */
    @Override
    public void open() throws Exception {}

    /**
     * Creates the state context, state handler, timer service manager and key selectors for this
     * wrapper; the wrapped operators later share this context via proxies.
     *
     * @param streamTaskStateManager the initializer provided by the containing stream task.
     */
    @Override
    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
            throws Exception {
        final TypeSerializer<?> keySerializer =
                streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());

        streamOperatorStateContext =
                streamTaskStateManager.streamOperatorStateContext(
                        getOperatorID(),
                        getClass().getSimpleName(),
                        parameters.getProcessingTimeService(),
                        this,
                        keySerializer,
                        containingTask.getCancelables(),
                        metrics,
                        streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
                                ManagedMemoryUseCase.STATE_BACKEND,
                                containingTask
                                        .getEnvironment()
                                        .getTaskManagerInfo()
                                        .getConfiguration(),
                                containingTask.getUserCodeClassLoader()),
                        isUsingCustomRawKeyedState());

        // The handler restores the operator state and triggers initializeState(context) below.
        stateHandler =
                new StreamOperatorStateHandler(
                        streamOperatorStateContext,
                        containingTask.getExecutionConfig(),
                        containingTask.getCancelables());
        stateHandler.initializeOperatorState(this);
        this.timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();

        // Key selectors for the two possible inputs; null when an input is not keyed.
        stateKeySelector1 =
                streamConfig.getStatePartitioner(0, containingTask.getUserCodeClassLoader());
        stateKeySelector2 =
                streamConfig.getStatePartitioner(1, containingTask.getUserCodeClassLoader());
    }

    /**
     * Restores the wrapper's own operator state and re-creates (and opens) the wrapped operators
     * for every epoch that had not yet been aligned when the checkpoint was taken.
     *
     * <p>The raw operator state streams can only be consumed sequentially, so the sorted
     * {@code rawStateEpoch} list (one entry per raw-state partition) is merged against the sorted
     * {@code pendingEpochs} list to determine how many raw-state partitions belong to each
     * restored epoch.
     *
     * @param context provides the operator state and raw state inputs to restore from.
     * @throws Exception if restoring state or opening a wrapped operator fails.
     */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        // Union-list state written only by subtask 0: the parallelism of the previous run. The
        // per-round wrapper does not support rescaling, so fail fast if it changed.
        parallelismState =
                context.getOperatorStateStore()
                        .getUnionListState(
                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
                .ifPresent(
                        oldParallelism ->
                                checkState(
                                        oldParallelism
                                                == containingTask
                                                        .getEnvironment()
                                                        .getTaskInfo()
                                                        .getNumberOfParallelSubtasks(),
                                        "The per-round wrapper operator is recovered with parallelism changed from "
                                                + oldParallelism
                                                + " to "
                                                + containingTask
                                                        .getEnvironment()
                                                        .getTaskInfo()
                                                        .getNumberOfParallelSubtasks()));

        latestEpochWatermarkState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
                .ifPresent(
                        oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);

        // Epoch of each raw operator state partition. Notes that the list must be sorted (it is
        // written in ascending epoch order by snapshotState).
        rawStateEpochState =
                context.getOperatorStateStore()
                        .getListState(new ListStateDescriptor<>("rawStateEpoch", Integer.class));
        List<Integer> rawStateEpochs = IteratorUtils.toList(rawStateEpochState.get().iterator());

        // Epochs whose wrapped operators were still alive at checkpoint time. Notes that the list
        // must be sorted.
        pendingEpochState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("pendingEpochs", IntSerializer.INSTANCE));
        List<Integer> pendingEpochs = IteratorUtils.toList(pendingEpochState.get().iterator());

        // Unfortunately, for the raw state we could not call get input stream unless the previous
        // records are consumed. We would have to do a "merge" of the two lists.
        Iterator<StatePartitionStreamProvider> rawStates =
                context.getRawOperatorStateInputs().iterator();

        int nextRawStateEntryIndex = 0;
        for (int epoch : pendingEpochs) {
            checkState(
                    nextRawStateEntryIndex == rawStateEpochs.size()
                            || rawStateEpochs.get(nextRawStateEntryIndex) >= epoch,
                    String.format(
                            "Unexpected raw state indices %s and epochs %s",
                            rawStateEpochs, pendingEpochs));
            // Let's find how many entries this epoch has.
            int numberOfStateEntries = 0;
            while (nextRawStateEntryIndex < rawStateEpochs.size()
                    && rawStateEpochs.get(nextRawStateEntryIndex) == epoch) {
                numberOfStateEntries++;
                nextRawStateEntryIndex++;
            }

            // Re-create and open the wrapped operator for this epoch, handing it its share of the
            // sequential raw state stream.
            getWrappedOperator(epoch, rawStates, numberOfStateEntries);
        }
    }

    /**
     * Whether this operator writes custom raw keyed state. Subclasses that do should override
     * this to return {@code true}; the flag is passed through to state initialization and to
     * {@code stateHandler.snapshotState} so the raw keyed stream is handled accordingly.
     */
    @Internal
    protected boolean isUsingCustomRawKeyedState() {
        return false;
    }

    /**
     * Verifies that all per-round wrapped operators have already been closed when the iteration
     * finishes.
     *
     * @throws IllegalStateException if some rounds still have live wrapped operators.
     */
    @Override
    public void finish() throws Exception {
        checkState(
                wrappedOperators.isEmpty(),
                "Some wrapped operators are still not closed yet: " + wrappedOperators.keySet());
    }

    /** Disposes the state handler (and with it the state backends), if one was created. */
    @Override
    public void close() throws Exception {
        if (stateHandler == null) {
            return;
        }

        stateHandler.dispose();
    }

    /**
     * Forwards the pre-barrier notification to every live per-round wrapped operator.
     *
     * @param checkpointId the id of the upcoming checkpoint.
     */
    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        // Only the values are needed, so iterate them directly instead of the entry set.
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
        }
    }

    /**
     * Delegates snapshotting to the state handler, which calls back into
     * {@link #snapshotState(StateSnapshotContext)} for the actual writing of the wrapped
     * operators' and the wrapper's own state.
     */
    @Override
    public OperatorSnapshotFutures snapshotState(
            long checkpointId,
            long timestamp,
            CheckpointOptions checkpointOptions,
            CheckpointStreamFactory factory)
            throws Exception {
        return stateHandler.snapshotState(
                this,
                Optional.ofNullable(timeServiceManager),
                streamConfig.getOperatorName(),
                checkpointId,
                timestamp,
                checkpointOptions,
                factory,
                isUsingCustomRawKeyedState());
    }

    /**
     * Snapshots the states of all live wrapped operators (in ascending epoch order) and then the
     * wrapper's own states.
     *
     * <p>For each wrapped operator the number of raw operator-state partitions it wrote is
     * recorded in {@code operatorStateEpoch}, one epoch entry per partition; this list lets
     * {@link #initializeState} later assign every restored operator exactly its share of the
     * sequential raw state stream.
     *
     * @param context the snapshot context wrapped into proxy contexts for the wrapped operators.
     * @throws Exception if any wrapped operator fails to snapshot its state.
     */
    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        OperatorStateCheckpointOutputStream rawOperatorStateOutputStream =
                context.getRawOperatorStateOutput();
        List<Integer> operatorStateEpoch = new ArrayList<>();

        List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
        Collections.sort(sortedEpochs);

        for (int epoch : sortedEpochs) {
            S wrappedOperator = wrappedOperators.get(epoch);
            // instanceof is equivalent to Class#isAssignableFrom on a non-null value and clearer.
            if (wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) {
                ((StreamOperatorStateHandler.CheckpointedStreamOperator) wrappedOperator)
                        .snapshotState(new ProxyStateSnapshotContext(context));

                // Gets the count of the raw operator state partitions written so far and tags
                // every newly written partition with the current epoch.
                int numberOfPartitions = rawOperatorStateOutputStream.getNumberOfPartitions();
                while (operatorStateEpoch.size() < numberOfPartitions) {
                    operatorStateEpoch.add(epoch);
                }
            }
        }

        // Then snapshot our own states
        // Always clear the union list state before set value.
        parallelismState.clear();
        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
            parallelismState.update(
                    Collections.singletonList(
                            containingTask
                                    .getEnvironment()
                                    .getTaskInfo()
                                    .getNumberOfParallelSubtasks()));
        }
        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));

        // The list must be sorted
        rawStateEpochState.update(operatorStateEpoch);

        // The list must be sorted
        pendingEpochState.update(sortedEpochs);
    }

    /** Sets the key context from a first-input record using the first input's key selector. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement1(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector1);
    }

    /** Sets the key context from a second-input record using the second input's key selector. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement2(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector2);
    }

    /**
     * Extracts the key from {@code record} and sets it as the current key, but only when a key
     * selector is configured and the record wraps a normal data record ({@code RECORD} type);
     * epoch watermarks and other iteration events never update the key context.
     *
     * <p>The method type parameter is named {@code IN} (instead of the original {@code T}) to
     * avoid shadowing the class-level type parameter {@code T}.
     *
     * @param record the incoming record whose value is an {@code IterationRecord}.
     * @param selector the key selector of the corresponding input, possibly {@code null}.
     */
    private <IN> void setKeyContextElement(StreamRecord<IN> record, KeySelector<IN, ?> selector)
            throws Exception {
        if (selector != null
                && ((IterationRecord<?>) record.getValue()).getType()
                        == IterationRecord.Type.RECORD) {
            Object key = selector.getKey(record.getValue());
            setCurrentKey(key);
        }
    }

    /** Returns the metric group of this wrapper operator. */
    @Override
    public OperatorMetricGroup getMetricGroup() {
        return metrics;
    }

    /** Returns the operator id from the wrapper's stream config. */
    @Override
    public OperatorID getOperatorID() {
        return streamConfig.getOperatorID();
    }

    /**
     * Forwards the checkpoint-complete notification to every live per-round wrapped operator.
     *
     * @param checkpointId the id of the completed checkpoint.
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        // Only the values are needed, so iterate them directly instead of the entry set.
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointComplete(checkpointId);
        }
    }

    /**
     * Forwards the checkpoint-aborted notification to every live per-round wrapped operator.
     *
     * @param checkpointId the id of the aborted checkpoint.
     */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        // Only the values are needed, so iterate them directly instead of the entry set.
        for (S wrappedOperator : wrappedOperators.values()) {
            wrappedOperator.notifyCheckpointAborted(checkpointId);
        }
    }

    /**
     * Sets the current key on the state handler.
     *
     * <p>NOTE(review): unlike {@link #getCurrentKey()}, this does not guard against a {@code null}
     * {@code stateHandler}; presumably it is only invoked after state initialization — confirm.
     */
    @Override
    public void setCurrentKey(Object key) {
        stateHandler.setCurrentKey(key);
    }

    /** Returns the current key, or {@code null} if the state handler is not yet initialized. */
    @Override
    public Object getCurrentKey() {
        return stateHandler == null ? null : stateHandler.getCurrentKey();
    }

    /** Records the latency marker in this operator's {@link LatencyStats} and forwards it. */
    protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
        // all operators are tracking latencies
        this.latencyStats.reportLatency(marker);

        // everything except sinks forwards latency markers
        this.output.emitLatencyMarker(marker);
    }

    /**
     * Builds the {@link LatencyStats} for this operator from the task manager configuration,
     * substituting defaults for invalid history-size or granularity settings, and falling back to
     * an unregistered metric group if instantiation fails entirely.
     */
    private LatencyStats initializeLatencyStats() {
        try {
            Configuration tmConfig =
                    containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();

            // A non-positive history size is invalid; warn and fall back to the default.
            int latencyHistorySize = tmConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
            if (latencyHistorySize <= 0) {
                LOG.warn(
                        "{} has been set to a value equal or below 0: {}. Using default.",
                        MetricOptions.LATENCY_HISTORY_SIZE,
                        latencyHistorySize);
                latencyHistorySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
            }

            // An unparsable granularity falls back to per-operator tracking.
            final String granularitySetting =
                    tmConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
            LatencyStats.Granularity latencyGranularity;
            try {
                latencyGranularity =
                        LatencyStats.Granularity.valueOf(
                                granularitySetting.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException invalidValue) {
                latencyGranularity = LatencyStats.Granularity.OPERATOR;
                LOG.warn(
                        "Configured value {} option for {} is invalid. Defaulting to {}.",
                        granularitySetting,
                        MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
                        latencyGranularity);
            }

            return new LatencyStats(
                    this.metrics.getJobMetricGroup().addGroup("latency"),
                    latencyHistorySize,
                    containingTask.getIndexInSubtaskGroup(),
                    getOperatorID(),
                    latencyGranularity);
        } catch (Exception e) {
            // Latency metrics are best-effort: never let their setup fail the operator.
            LOG.warn("An error occurred while instantiating latency metrics.", e);
            return new LatencyStats(
                    UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
                            .addGroup("latency"),
                    1,
                    0,
                    new OperatorID(),
                    LatencyStats.Granularity.SINGLE);
        }
    }

    /**
     * Initializes and opens a freshly created wrapped operator for the given round.
     *
     * <p>The operator is initialized through a {@link ProxyStreamOperatorStateContext} that
     * delegates to this wrapper's already-initialized state context, prefixing state names with
     * the round's prefix and exposing only {@code count} partitions of the sequential raw
     * operator state iterator.
     *
     * @param operator the wrapped operator of this round.
     * @param round the epoch this operator serves; used to derive the state-name prefix.
     * @param rawOperatorStates iterator over the remaining raw operator state partitions.
     * @param count how many partitions of {@code rawOperatorStates} belong to this operator.
     */
    private void initializeStreamOperator(
            S operator,
            int round,
            Iterator<StatePartitionStreamProvider> rawOperatorStates,
            int count)
            throws Exception {
        // All parameters of the context-factory lambda are required by the functional interface
        // but intentionally ignored: the proxy context reuses this wrapper's state context
        // instead of creating new backends.
        operator.initializeState(
                (operatorID,
                        operatorClassName,
                        processingTimeService,
                        keyContext,
                        keySerializer,
                        streamTaskCloseableRegistry,
                        metricGroup,
                        managedMemoryFraction,
                        isUsingCustomRawKeyedState) ->
                        new ProxyStreamOperatorStateContext(
                                streamOperatorStateContext,
                                getRoundStatePrefix(round),
                                rawOperatorStates,
                                count));
        operator.open();
    }

    /**
     * Drops all operator-state entries whose names carry the given round's prefix from the
     * operator state backend, so a finished round leaves no state behind.
     *
     * <p>{@link DefaultOperatorStateBackend} offers no public API for removing individual states,
     * hence the reflective access to its internal bookkeeping maps. For other backend
     * implementations a warning is logged and nothing is removed.
     *
     * @param round the epoch whose per-round operator states should be discarded.
     */
    private void cleanupOperatorStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend();

        if (operatorStateBackend instanceof DefaultOperatorStateBackend) {
            for (String fieldName :
                    new String[] {
                        "registeredOperatorStates",
                        "registeredBroadcastStates",
                        "accessedStatesByName",
                        "accessedBroadcastStatesByName"
                    }) {
                Map<String, ?> states =
                        ReflectionUtils.getFieldValue(
                                operatorStateBackend,
                                DefaultOperatorStateBackend.class,
                                fieldName);
                states.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix));
            }
        } else {
            LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend);
        }
    }

    private void cleanupKeyedStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend();
        if (keyedStateBackend.getClass().getName().equals(HEAP_KEYED_STATE_NAME)) {
            ReflectionUtils.<Map<String, ?>>getFieldValue(
                            keyedStateBackend, HeapKeyedStateBackend.class, "registeredKVStates")
                    .entrySet()
                    .removeIf(entry -> entry.getKey().startsWith(roundPrefix));
            ReflectionUtils.<Map<String, ?>>getFieldValue(
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



