flink-ml-iteration/flink-ml-iteration-1.15/src/main/java/org/apache/flink/iteration/Iterations.java [108:596]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@Experimental
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams.
     * The iteration will not terminate if at least one of its inputs is unbounded. Otherwise it
     * terminates after all the inputs have terminated and no more records are iterating.
     *
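     * <p>A minimal usage sketch is shown below. The {@code initialModel} and {@code inputData}
     * streams are assumed to be defined by the caller, and the trivial body that simply feeds the
     * variable back is only for illustration:
     *
     * <pre>{@code
     * DataStreamList outputs =
     *         Iterations.iterateUnboundedStreams(
     *                 DataStreamList.of(initialModel),
     *                 DataStreamList.of(inputData),
     *                 (variableStreams, dataStreams) -> {
     *                     DataStream<Integer> variable = variableStreams.get(0);
     *                     // The feedback must keep the same parallelism as the initial stream.
     *                     DataStream<Integer> feedback =
     *                             variable.map(x -> x).returns(Types.INT);
     *                     return new IterationBodyResult(
     *                             DataStreamList.of(feedback),
     *                             DataStreamList.of(variable));
     *                 });
     * }</pre>
     *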
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the first parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams that are also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return createIteration(
                initVariableStreams,
                dataStreams,
                Collections.emptySet(),
                body,
                new AllRoundOperatorWrapper(),
                false);
    }

    /**
     * This method uses an iteration body to iteratively process records in some bounded data
     * streams until no more records are iterating or the given termination criteria stream is
     * empty in one round.
     *
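     * <p>A minimal usage sketch is shown below. The {@code initialModel} and {@code trainData}
     * streams are assumed to be defined by the caller, and {@code updateModel} / {@code
     * notConverged} are hypothetical placeholders for user-defined logic:
     *
     * <pre>{@code
     * DataStreamList outputs =
     *         Iterations.iterateBoundedStreamsUntilTermination(
     *                 DataStreamList.of(initialModel),
     *                 ReplayableDataStreamList.notReplay(trainData),
     *                 IterationConfig.newBuilder().build(),
     *                 (variableStreams, dataStreams) -> {
     *                     DataStream<double[]> model = variableStreams.get(0);
     *                     DataStream<double[]> newModel = updateModel(model, dataStreams.get(0));
     *                     // The iteration terminates once this stream is empty in one round.
     *                     DataStream<double[]> criteria = newModel.filter(m -> notConverged(m));
     *                     return new IterationBodyResult(
     *                             DataStreamList.of(newModel),
     *                             DataStreamList.of(newModel),
     *                             criteria);
     *                 });
     * }</pre>
     *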
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the first parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams that are also referred to in the {@code body},
     *     together with whether each of them needs to be replayed for each round.
     * @param config The config for the iteration, such as whether to re-create the operators on
     *     each round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        OperatorWrapper wrapper =
                config.getOperatorLifeCycle() == IterationConfig.OperatorLifeCycle.ALL_ROUND
                        ? new AllRoundOperatorWrapper<>()
                        : new PerRoundOperatorWrapper<>();

        List<DataStream<?>> allDatastreams = new ArrayList<>();
        allDatastreams.addAll(dataStreams.getReplayedDataStreams());
        allDatastreams.addAll(dataStreams.getNonReplayedStreams());

        Set<Integer> replayedIndices =
                IntStream.range(0, dataStreams.getReplayedDataStreams().size())
                        .boxed()
                        .collect(Collectors.toSet());

        return createIteration(
                initVariableStreams,
                new DataStreamList(allDatastreams),
                replayedIndices,
                body,
                wrapper,
                true);
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStreamList createIteration(
            DataStreamList initVariableStreams,
            DataStreamList dataStreams,
            Set<Integer> replayedDataStreamIndices,
            IterationBody body,
            OperatorWrapper<?, IterationRecord<?>> initialOperatorWrapper,
            boolean mayHaveCriteria) {
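        // Overall wiring: every initial variable / data stream is first wrapped into
        // IterationRecord by an input operator; the variable inputs then pass through head
        // operators (replayed data streams additionally go through a replay operator); the
        // user-defined body is built inside a draft execution environment on top of the head and
        // data streams; the feedback streams are connected to tail operators co-located with the
        // heads; and the body outputs are finally unwrapped by output operators, which are
        // additionally unioned with an (always empty) stream derived from all the tails.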
        checkState(initVariableStreams.size() > 0, "There should be at least one variable stream");

        IterationID iterationId = new IterationID();

        List<TypeInformation<?>> initVariableTypeInfos = getTypeInfos(initVariableStreams);
        List<TypeInformation<?>> dataStreamTypeInfos = getTypeInfos(dataStreams);

        // Add heads and inputs
        int totalInitVariableParallelism =
                map(
                                initVariableStreams,
                                dataStream ->
                                        dataStream.getParallelism() > 0
                                                ? dataStream.getParallelism()
                                                : dataStream
                                                        .getExecutionEnvironment()
                                                        .getConfig()
                                                        .getParallelism())
                        .stream()
                        .mapToInt(i -> i)
                        .sum();
        DataStreamList initVariableInputs = addInputs(initVariableStreams);
        DataStreamList headStreams =
                addHeads(
                        initVariableStreams,
                        initVariableInputs,
                        iterationId,
                        totalInitVariableParallelism,
                        false,
                        0);

        DataStreamList dataStreamInputs = addInputs(dataStreams);
        if (replayedDataStreamIndices.size() > 0) {
            dataStreamInputs =
                    addReplayer(
                            headStreams.get(0),
                            dataStreams,
                            dataStreamInputs,
                            replayedDataStreamIndices);
        }

        // Creates the iteration body. We map the inputs of the iteration body to draft sources,
        // which serve as the starting points for building the draft subgraph.
        StreamExecutionEnvironment env = initVariableStreams.get(0).getExecutionEnvironment();
        DraftExecutionEnvironment draftEnv =
                new DraftExecutionEnvironment(env, initialOperatorWrapper);
        DataStreamList draftHeadStreams =
                addDraftSources(headStreams, draftEnv, initVariableTypeInfos);
        DataStreamList draftDataStreamInputs =
                addDraftSources(dataStreamInputs, draftEnv, dataStreamTypeInfos);

        IterationBodyResult iterationBodyResult =
                body.process(draftHeadStreams, draftDataStreamInputs);
        ensuresTransformationAdded(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        ensuresTransformationAdded(iterationBodyResult.getOutputStreams(), draftEnv);
        draftEnv.copyToActualEnvironment();

        // Adds the tails and co-locates them with the heads.
        DataStreamList feedbackStreams =
                getActualDataStreams(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        checkState(
                feedbackStreams.size() == initVariableStreams.size(),
                "The number of feedback streams "
                        + feedbackStreams.size()
                        + " does not match the initialized one "
                        + initVariableStreams.size());
        for (int i = 0; i < feedbackStreams.size(); ++i) {
            checkState(
                    feedbackStreams.get(i).getParallelism() == headStreams.get(i).getParallelism(),
                    String.format(
                            "The feedback stream %d have different parallelism %d with the initial stream, which is %d",
                            i,
                            feedbackStreams.get(i).getParallelism(),
                            headStreams.get(i).getParallelism()));
        }

        DataStreamList tails = addTails(feedbackStreams, iterationId, 0);
        for (int i = 0; i < headStreams.size(); ++i) {
            String coLocationGroupKey = "co-" + iterationId.toHexString() + "-" + i;
            headStreams.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
            tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        }

        List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());
        checkState(
                mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
                "The current iteration type does not support the termination criteria.");

        if (iterationBodyResult.getTerminationCriteria() != null) {
            DataStreamList criteriaTails =
                    addCriteriaStream(
                            iterationBodyResult.getTerminationCriteria(),
                            iterationId,
                            env,
                            draftEnv,
                            initVariableStreams,
                            headStreams,
                            totalInitVariableParallelism);
            tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
        }

        DataStream<Integer> tailsUnion =
                unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));

        return addOutputs(
                getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
    }

    private static DataStreamList addReplayer(
            DataStream<?> firstHeadStream,
            DataStreamList originalDataStreams,
            DataStreamList dataStreamInputs,
            Set<Integer> replayedDataStreamIndices) {

        List<DataStream<?>> result = new ArrayList<>(dataStreamInputs.size());
        for (int i = 0; i < dataStreamInputs.size(); ++i) {
            if (!replayedDataStreamIndices.contains(i)) {
                result.add(dataStreamInputs.get(i));
                continue;
            }

            // Note that the HeadOperator broadcasts the globally aligned events, thus the
            // operator does not need to emit to the side output specially.
            DataStream<?> replayedInput =
                    dataStreamInputs
                            .get(i)
                            .connect(
                                    ((SingleOutputStreamOperator<IterationRecord<?>>)
                                                    firstHeadStream)
                                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
                                            .broadcast())
                            .transform(
                                    "Replayer-"
                                            + originalDataStreams
                                                    .get(i)
                                                    .getTransformation()
                                                    .getName(),
                                    dataStreamInputs.get(i).getType(),
                                    (TwoInputStreamOperator) new ReplayOperator<>())
                            .setParallelism(dataStreamInputs.get(i).getParallelism());
            result.add(replayedInput);
        }

        return new DataStreamList(result);
    }

    private static DataStreamList addCriteriaStream(
            DataStream<?> draftCriteriaStream,
            IterationID iterationId,
            StreamExecutionEnvironment env,
            DraftExecutionEnvironment draftEnv,
            DataStreamList initVariableStreams,
            DataStreamList headStreams,
            int totalInitVariableParallelism) {
        // Deals with the criteria streams
        DataStream<?> terminationCriteria = draftEnv.getActualStream(draftCriteriaStream.getId());
        // It should always have the IterationRecordTypeInfo
        checkState(
                terminationCriteria.getType().getClass().equals(IterationRecordTypeInfo.class),
                "The termination criteria should always return IterationRecord.");
        TypeInformation<?> innerType =
                ((IterationRecordTypeInfo<?>) terminationCriteria.getType()).getInnerTypeInfo();

        DataStream<?> emptyCriteriaSource =
                env.addSource(new DraftExecutionEnvironment.EmptySource())
                        .returns(innerType)
                        .name(terminationCriteria.getTransformation().getName())
                        .setParallelism(terminationCriteria.getParallelism());
        DataStreamList criteriaSources = DataStreamList.of(emptyCriteriaSource);
        DataStreamList criteriaInputs = addInputs(criteriaSources);
        DataStreamList criteriaHeaders =
                addHeads(
                        criteriaSources,
                        criteriaInputs,
                        iterationId,
                        totalInitVariableParallelism,
                        true,
                        initVariableStreams.size());

        // Merges the head and the actual criteria stream. This is required since, if there were no
        // edges from the criteria head to the criteria tail, the tail might directly receive the
        // MAX_EPOCH_WATERMARK without the synchronization of the head.
        DataStream<?> mergedHeadAndCriteria =
                mergeCriteriaHeadAndCriteriaStream(
                        env, criteriaHeaders.get(0), terminationCriteria, innerType);
        DataStreamList criteriaTails =
                addTails(
                        DataStreamList.of(mergedHeadAndCriteria),
                        iterationId,
                        initVariableStreams.size());

        String coLocationGroupKey = "co-" + iterationId.toHexString() + "-cri";
        criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);

        // Now we notify all the head operators to count the criteria streams.
        setCriteriaParallelism(headStreams, terminationCriteria.getParallelism());
        setCriteriaParallelism(criteriaHeaders, terminationCriteria.getParallelism());

        return criteriaTails;
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStream<?> mergeCriteriaHeadAndCriteriaStream(
            StreamExecutionEnvironment env,
            DataStream<?> head,
            DataStream<?> criteriaStream,
            TypeInformation<?> criteriaStreamType) {
        DraftExecutionEnvironment criteriaDraftEnv =
                new DraftExecutionEnvironment(env, new AllRoundOperatorWrapper<>());
        DataStream draftHeadStream = criteriaDraftEnv.addDraftSource(head, criteriaStreamType);
        DataStream draftTerminationCriteria =
                criteriaDraftEnv.addDraftSource(criteriaStream, criteriaStreamType);
        DataStream draftMergedStream =
                draftHeadStream
                        .connect(draftTerminationCriteria)
                        .process(new CriteriaMergeProcessor())
                        .returns(criteriaStreamType)
                        .setParallelism(
                                criteriaStream.getParallelism() > 0
                                        ? criteriaStream.getParallelism()
                                        : env.getConfig().getParallelism())
                        .name("criteria-merge");
        criteriaDraftEnv.copyToActualEnvironment();
        return criteriaDraftEnv.getActualStream(draftMergedStream.getId());
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStream<Integer> unionAllTails(
            StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) {

        return Iterations.<DataStream>map(
                        tailsAndCriteriaTails,
                        tail ->
                                tail.filter(r -> false)
                                        .name("filter-tail")
                                        .returns((TypeInformation) Types.INT)
                                        .setParallelism(
                                                tail.getParallelism() > 0
                                                        ? tail.getParallelism()
                                                        : env.getConfig().getParallelism()))
                .stream()
                .reduce(DataStream::union)
                .get();
    }

    private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
        return map(dataStreams, DataStream::getType);
    }

    private static DataStreamList addInputs(DataStreamList dataStreams) {
        return new DataStreamList(
                map(
                        dataStreams,
                        dataStream ->
                                dataStream
                                        .transform(
                                                "input-" + dataStream.getTransformation().getName(),
                                                new IterationRecordTypeInfo<>(dataStream.getType()),
                                                new InputOperator())
                                        .setParallelism(dataStream.getParallelism())));
    }

    private static DataStreamList addHeads(
            DataStreamList variableStreams,
            DataStreamList inputStreams,
            IterationID iterationId,
            int totalInitVariableParallelism,
            boolean isCriteriaStream,
            int startHeaderIndex) {

        return new DataStreamList(
                map(
                        inputStreams,
                        (index, dataStream) ->
                                ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
                                        .transform(
                                                "head-"
                                                        + variableStreams
                                                                .get(index)
                                                                .getTransformation()
                                                                .getName(),
                                                (IterationRecordTypeInfo) dataStream.getType(),
                                                new HeadOperatorFactory(
                                                        iterationId,
                                                        startHeaderIndex + index,
                                                        isCriteriaStream,
                                                        totalInitVariableParallelism))
                                        .setParallelism(dataStream.getParallelism())));
    }

    private static DataStreamList addTails(
            DataStreamList dataStreams, IterationID iterationId, int startIndex) {
        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) -> {
                            Transformation<?> inputTransformation = dataStream.getTransformation();
                            if (!(inputTransformation instanceof PhysicalTransformation)
                                    && inputTransformation.getInputs().size() > 1) {
                                // TODO: Support epoch watermark alignment for TailOperator.
                                throw new UnsupportedOperationException(
                                        "Tail operator should have only one input. Please check whether operator \""
                                                + inputTransformation.getName()
                                                + "\" contains multiple inputs.");
                            }

                            return ((DataStream<IterationRecord<?>>) dataStream)
                                    .transform(
                                            "tail-" + dataStream.getTransformation().getName(),
                                            new IterationRecordTypeInfo(dataStream.getType()),
                                            new TailOperator(iterationId, startIndex + index))
                                    .setParallelism(dataStream.getParallelism());
                        }));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) {
        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) -> {
                            IterationRecordTypeInfo<?> inputType =
                                    (IterationRecordTypeInfo<?>) dataStream.getType();
                            return dataStream
                                    .union(
                                            tailsUnion
                                                    .map(x -> x)
                                                    .name(
                                                            "tail-map-"
                                                                    + dataStream
                                                                            .getTransformation()
                                                                            .getName())
                                                    .returns(inputType)
                                                    .setParallelism(1))
                                    .transform(
                                            "output-" + dataStream.getTransformation().getName(),
                                            inputType.getInnerTypeInfo(),
                                            new OutputOperator())
                                    .setParallelism(dataStream.getParallelism());
                        }));
    }

    private static DataStreamList addDraftSources(
            DataStreamList dataStreams,
            DraftExecutionEnvironment draftEnv,
            List<TypeInformation<?>> typeInfos) {

        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) ->
                                draftEnv.addDraftSource(dataStream, typeInfos.get(index))));
    }

    private static void ensuresTransformationAdded(
            DataStreamList dataStreams, DraftExecutionEnvironment draftEnv) {
        map(
                dataStreams,
                dataStream -> {
                    draftEnv.addOperatorIfNotExists(dataStream.getTransformation());
                    return null;
                });
    }

    private static void setCriteriaParallelism(
            DataStreamList headStreams, int criteriaParallelism) {
        map(
                headStreams,
                dataStream -> {
                    ((HeadOperatorFactory)
                                    ((OneInputTransformation) dataStream.getTransformation())
                                            .getOperatorFactory())
                            .setCriteriaStreamParallelism(criteriaParallelism);
                    return null;
                });
    }

    private static DataStreamList getActualDataStreams(
            DataStreamList draftStreams, DraftExecutionEnvironment draftEnv) {
        return new DataStreamList(
                map(draftStreams, dataStream -> draftEnv.getActualStream(dataStream.getId())));
    }

    private static <R> List<R> map(DataStreamList dataStreams, Function<DataStream<?>, R> mapper) {
        return map(dataStreams, (i, dataStream) -> mapper.apply(dataStream));
    }

    private static <R> List<R> map(
            DataStreamList dataStreams, BiFunction<Integer, DataStream<?>, R> mapper) {
        List<R> results = new ArrayList<>(dataStreams.size());
        for (int i = 0; i < dataStreams.size(); ++i) {
            DataStream<?> dataStream = dataStreams.get(i);
            results.add(mapper.apply(i, dataStream));
        }

        return results;
    }

    private static class CriteriaMergeProcessor extends CoProcessFunction<Object, Object, Object> {

        @Override
        public void processElement1(Object value, Context ctx, Collector<Object> out)
                throws Exception {
            // Ignores all the records from the head side-output.
        }

        @Override
        public void processElement2(Object value, Context ctx, Collector<Object> out)
                throws Exception {
            // Bypasses all the records from the actual criteria stream.
            out.collect(value);
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java [108:596]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@Experimental
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams.
     * The iteration will not terminate if at least one of its inputs is unbounded. Otherwise it
     * terminates after all the inputs have terminated and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the first parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams that are also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return createIteration(
                initVariableStreams,
                dataStreams,
                Collections.emptySet(),
                body,
                new AllRoundOperatorWrapper(),
                false);
    }

    /**
     * This method uses an iteration body to iteratively process records in some bounded data
     * streams until no more records are iterating or the given termination criteria stream is
     * empty in one round.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the first parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams that are also referred to in the {@code body},
     *     together with whether each of them needs to be replayed for each round.
     * @param config The config for the iteration, such as whether to re-create the operators on
     *     each round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        OperatorWrapper wrapper =
                config.getOperatorLifeCycle() == IterationConfig.OperatorLifeCycle.ALL_ROUND
                        ? new AllRoundOperatorWrapper<>()
                        : new PerRoundOperatorWrapper<>();

        List<DataStream<?>> allDatastreams = new ArrayList<>();
        allDatastreams.addAll(dataStreams.getReplayedDataStreams());
        allDatastreams.addAll(dataStreams.getNonReplayedStreams());

        Set<Integer> replayedIndices =
                IntStream.range(0, dataStreams.getReplayedDataStreams().size())
                        .boxed()
                        .collect(Collectors.toSet());

        return createIteration(
                initVariableStreams,
                new DataStreamList(allDatastreams),
                replayedIndices,
                body,
                wrapper,
                true);
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStreamList createIteration(
            DataStreamList initVariableStreams,
            DataStreamList dataStreams,
            Set<Integer> replayedDataStreamIndices,
            IterationBody body,
            OperatorWrapper<?, IterationRecord<?>> initialOperatorWrapper,
            boolean mayHaveCriteria) {
        checkState(initVariableStreams.size() > 0, "There should be at least one variable stream");

        IterationID iterationId = new IterationID();

        List<TypeInformation<?>> initVariableTypeInfos = getTypeInfos(initVariableStreams);
        List<TypeInformation<?>> dataStreamTypeInfos = getTypeInfos(dataStreams);

        // Add heads and inputs
        int totalInitVariableParallelism =
                map(
                                initVariableStreams,
                                dataStream ->
                                        dataStream.getParallelism() > 0
                                                ? dataStream.getParallelism()
                                                : dataStream
                                                        .getExecutionEnvironment()
                                                        .getConfig()
                                                        .getParallelism())
                        .stream()
                        .mapToInt(i -> i)
                        .sum();
        DataStreamList initVariableInputs = addInputs(initVariableStreams);
        DataStreamList headStreams =
                addHeads(
                        initVariableStreams,
                        initVariableInputs,
                        iterationId,
                        totalInitVariableParallelism,
                        false,
                        0);

        DataStreamList dataStreamInputs = addInputs(dataStreams);
        if (replayedDataStreamIndices.size() > 0) {
            dataStreamInputs =
                    addReplayer(
                            headStreams.get(0),
                            dataStreams,
                            dataStreamInputs,
                            replayedDataStreamIndices);
        }

        // Creates the iteration body. We map the inputs of the iteration body to draft sources,
        // which serve as the starting points for building the draft subgraph.
        StreamExecutionEnvironment env = initVariableStreams.get(0).getExecutionEnvironment();
        DraftExecutionEnvironment draftEnv =
                new DraftExecutionEnvironment(env, initialOperatorWrapper);
        DataStreamList draftHeadStreams =
                addDraftSources(headStreams, draftEnv, initVariableTypeInfos);
        DataStreamList draftDataStreamInputs =
                addDraftSources(dataStreamInputs, draftEnv, dataStreamTypeInfos);

        IterationBodyResult iterationBodyResult =
                body.process(draftHeadStreams, draftDataStreamInputs);
        ensuresTransformationAdded(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        ensuresTransformationAdded(iterationBodyResult.getOutputStreams(), draftEnv);
        draftEnv.copyToActualEnvironment();

        // Adds the tails and co-locates them with the heads.
        DataStreamList feedbackStreams =
                getActualDataStreams(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        checkState(
                feedbackStreams.size() == initVariableStreams.size(),
                "The number of feedback streams "
                        + feedbackStreams.size()
                        + " does not match the initialized one "
                        + initVariableStreams.size());
        for (int i = 0; i < feedbackStreams.size(); ++i) {
            checkState(
                    feedbackStreams.get(i).getParallelism() == headStreams.get(i).getParallelism(),
                    String.format(
                            "The feedback stream %d have different parallelism %d with the initial stream, which is %d",
                            i,
                            feedbackStreams.get(i).getParallelism(),
                            headStreams.get(i).getParallelism()));
        }

        DataStreamList tails = addTails(feedbackStreams, iterationId, 0);
        for (int i = 0; i < headStreams.size(); ++i) {
            String coLocationGroupKey = "co-" + iterationId.toHexString() + "-" + i;
            headStreams.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
            tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        }

        List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());
        checkState(
                mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
                "The current iteration type does not support the termination criteria.");

        if (iterationBodyResult.getTerminationCriteria() != null) {
            DataStreamList criteriaTails =
                    addCriteriaStream(
                            iterationBodyResult.getTerminationCriteria(),
                            iterationId,
                            env,
                            draftEnv,
                            initVariableStreams,
                            headStreams,
                            totalInitVariableParallelism);
            tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
        }

        DataStream<Integer> tailsUnion =
                unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));

        return addOutputs(
                getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
    }

    private static DataStreamList addReplayer(
            DataStream<?> firstHeadStream,
            DataStreamList originalDataStreams,
            DataStreamList dataStreamInputs,
            Set<Integer> replayedDataStreamIndices) {

        List<DataStream<?>> result = new ArrayList<>(dataStreamInputs.size());
        for (int i = 0; i < dataStreamInputs.size(); ++i) {
            if (!replayedDataStreamIndices.contains(i)) {
                result.add(dataStreamInputs.get(i));
                continue;
            }

            // Note that the HeadOperator broadcasts the globally aligned events, thus the
            // operator does not need to emit to the side output specially.
            DataStream<?> replayedInput =
                    dataStreamInputs
                            .get(i)
                            .connect(
                                    ((SingleOutputStreamOperator<IterationRecord<?>>)
                                                    firstHeadStream)
                                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
                                            .broadcast())
                            .transform(
                                    "Replayer-"
                                            + originalDataStreams
                                                    .get(i)
                                                    .getTransformation()
                                                    .getName(),
                                    dataStreamInputs.get(i).getType(),
                                    (TwoInputStreamOperator) new ReplayOperator<>())
                            .setParallelism(dataStreamInputs.get(i).getParallelism());
            result.add(replayedInput);
        }

        return new DataStreamList(result);
    }

    private static DataStreamList addCriteriaStream(
            DataStream<?> draftCriteriaStream,
            IterationID iterationId,
            StreamExecutionEnvironment env,
            DraftExecutionEnvironment draftEnv,
            DataStreamList initVariableStreams,
            DataStreamList headStreams,
            int totalInitVariableParallelism) {
        // Deals with the criteria streams
        DataStream<?> terminationCriteria = draftEnv.getActualStream(draftCriteriaStream.getId());
        // It should always have the IterationRecordTypeInfo
        checkState(
                terminationCriteria.getType().getClass().equals(IterationRecordTypeInfo.class),
                "The termination criteria should always return IterationRecord.");
        TypeInformation<?> innerType =
                ((IterationRecordTypeInfo<?>) terminationCriteria.getType()).getInnerTypeInfo();

        DataStream<?> emptyCriteriaSource =
                env.addSource(new DraftExecutionEnvironment.EmptySource())
                        .returns(innerType)
                        .name(terminationCriteria.getTransformation().getName())
                        .setParallelism(terminationCriteria.getParallelism());
        DataStreamList criteriaSources = DataStreamList.of(emptyCriteriaSource);
        DataStreamList criteriaInputs = addInputs(criteriaSources);
        DataStreamList criteriaHeaders =
                addHeads(
                        criteriaSources,
                        criteriaInputs,
                        iterationId,
                        totalInitVariableParallelism,
                        true,
                        initVariableStreams.size());

        // Merges the head and the actual criteria stream. This is required since, if there were no
        // edges from the criteria head to the criteria tail, the tail might directly receive the
        // MAX_EPOCH_WATERMARK without the synchronization of the head.
        DataStream<?> mergedHeadAndCriteria =
                mergeCriteriaHeadAndCriteriaStream(
                        env, criteriaHeaders.get(0), terminationCriteria, innerType);
        DataStreamList criteriaTails =
                addTails(
                        DataStreamList.of(mergedHeadAndCriteria),
                        iterationId,
                        initVariableStreams.size());

        String coLocationGroupKey = "co-" + iterationId.toHexString() + "-cri";
        criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);

        // Now we notify all the head operators to count the criteria streams.
        setCriteriaParallelism(headStreams, terminationCriteria.getParallelism());
        setCriteriaParallelism(criteriaHeaders, terminationCriteria.getParallelism());

        return criteriaTails;
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStream<?> mergeCriteriaHeadAndCriteriaStream(
            StreamExecutionEnvironment env,
            DataStream<?> head,
            DataStream<?> criteriaStream,
            TypeInformation<?> criteriaStreamType) {
        DraftExecutionEnvironment criteriaDraftEnv =
                new DraftExecutionEnvironment(env, new AllRoundOperatorWrapper<>());
        DataStream draftHeadStream = criteriaDraftEnv.addDraftSource(head, criteriaStreamType);
        DataStream draftTerminationCriteria =
                criteriaDraftEnv.addDraftSource(criteriaStream, criteriaStreamType);
        DataStream draftMergedStream =
                draftHeadStream
                        .connect(draftTerminationCriteria)
                        .process(new CriteriaMergeProcessor())
                        .returns(criteriaStreamType)
                        .setParallelism(
                                criteriaStream.getParallelism() > 0
                                        ? criteriaStream.getParallelism()
                                        : env.getConfig().getParallelism())
                        .name("criteria-merge");
        criteriaDraftEnv.copyToActualEnvironment();
        return criteriaDraftEnv.getActualStream(draftMergedStream.getId());
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStream<Integer> unionAllTails(
            StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) {

        return Iterations.<DataStream>map(
                        tailsAndCriteriaTails,
                        tail ->
                                tail.filter(r -> false)
                                        .name("filter-tail")
                                        .returns((TypeInformation) Types.INT)
                                        .setParallelism(
                                                tail.getParallelism() > 0
                                                        ? tail.getParallelism()
                                                        : env.getConfig().getParallelism()))
                .stream()
                .reduce(DataStream::union)
                .get();
    }

    private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) {
        return map(dataStreams, DataStream::getType);
    }

    private static DataStreamList addInputs(DataStreamList dataStreams) {
        return new DataStreamList(
                map(
                        dataStreams,
                        dataStream ->
                                dataStream
                                        .transform(
                                                "input-" + dataStream.getTransformation().getName(),
                                                new IterationRecordTypeInfo<>(dataStream.getType()),
                                                new InputOperator())
                                        .setParallelism(dataStream.getParallelism())));
    }

    private static DataStreamList addHeads(
            DataStreamList variableStreams,
            DataStreamList inputStreams,
            IterationID iterationId,
            int totalInitVariableParallelism,
            boolean isCriteriaStream,
            int startHeaderIndex) {

        return new DataStreamList(
                map(
                        inputStreams,
                        (index, dataStream) ->
                                ((SingleOutputStreamOperator<IterationRecord<?>>) dataStream)
                                        .transform(
                                                "head-"
                                                        + variableStreams
                                                                .get(index)
                                                                .getTransformation()
                                                                .getName(),
                                                (IterationRecordTypeInfo) dataStream.getType(),
                                                new HeadOperatorFactory(
                                                        iterationId,
                                                        startHeaderIndex + index,
                                                        isCriteriaStream,
                                                        totalInitVariableParallelism))
                                        .setParallelism(dataStream.getParallelism())));
    }

    private static DataStreamList addTails(
            DataStreamList dataStreams, IterationID iterationId, int startIndex) {
        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) -> {
                            Transformation<?> inputTransformation = dataStream.getTransformation();
                            if (!(inputTransformation instanceof PhysicalTransformation)
                                    && inputTransformation.getInputs().size() > 1) {
                                // TODO: Support epoch watermark alignment for TailOperator.
                                throw new UnsupportedOperationException(
                                        "Tail operator should have only one input. Please check whether operator \""
                                                + inputTransformation.getName()
                                                + "\" contains multiple inputs.");
                            }

                            return ((DataStream<IterationRecord<?>>) dataStream)
                                    .transform(
                                            "tail-" + dataStream.getTransformation().getName(),
                                            new IterationRecordTypeInfo(dataStream.getType()),
                                            new TailOperator(iterationId, startIndex + index))
                                    .setParallelism(dataStream.getParallelism());
                        }));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) {
        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) -> {
                            IterationRecordTypeInfo<?> inputType =
                                    (IterationRecordTypeInfo<?>) dataStream.getType();
                            return dataStream
                                    .union(
                                            tailsUnion
                                                    .map(x -> x)
                                                    .name(
                                                            "tail-map-"
                                                                    + dataStream
                                                                            .getTransformation()
                                                                            .getName())
                                                    .returns(inputType)
                                                    .setParallelism(1))
                                    .transform(
                                            "output-" + dataStream.getTransformation().getName(),
                                            inputType.getInnerTypeInfo(),
                                            new OutputOperator())
                                    .setParallelism(dataStream.getParallelism());
                        }));
    }

    private static DataStreamList addDraftSources(
            DataStreamList dataStreams,
            DraftExecutionEnvironment draftEnv,
            List<TypeInformation<?>> typeInfos) {

        return new DataStreamList(
                map(
                        dataStreams,
                        (index, dataStream) ->
                                draftEnv.addDraftSource(dataStream, typeInfos.get(index))));
    }

    private static void ensuresTransformationAdded(
            DataStreamList dataStreams, DraftExecutionEnvironment draftEnv) {
        map(
                dataStreams,
                dataStream -> {
                    draftEnv.addOperatorIfNotExists(dataStream.getTransformation());
                    return null;
                });
    }

    private static void setCriteriaParallelism(
            DataStreamList headStreams, int criteriaParallelism) {
        map(
                headStreams,
                dataStream -> {
                    ((HeadOperatorFactory)
                                    ((OneInputTransformation) dataStream.getTransformation())
                                            .getOperatorFactory())
                            .setCriteriaStreamParallelism(criteriaParallelism);
                    return null;
                });
    }

    private static DataStreamList getActualDataStreams(
            DataStreamList draftStreams, DraftExecutionEnvironment draftEnv) {
        return new DataStreamList(
                map(draftStreams, dataStream -> draftEnv.getActualStream(dataStream.getId())));
    }

    private static <R> List<R> map(DataStreamList dataStreams, Function<DataStream<?>, R> mapper) {
        return map(dataStreams, (i, dataStream) -> mapper.apply(dataStream));
    }

    private static <R> List<R> map(
            DataStreamList dataStreams, BiFunction<Integer, DataStream<?>, R> mapper) {
        List<R> results = new ArrayList<>(dataStreams.size());
        for (int i = 0; i < dataStreams.size(); ++i) {
            DataStream<?> dataStream = dataStreams.get(i);
            results.add(mapper.apply(i, dataStream));
        }

        return results;
    }

    private static class CriteriaMergeProcessor extends CoProcessFunction<Object, Object, Object> {

        @Override
        public void processElement1(Object value, Context ctx, Collector<Object> out)
                throws Exception {
            // Ignores all the records from the head side-output.
        }

        @Override
        public void processElement2(Object value, Context ctx, Collector<Object> out)
                throws Exception {
            // Bypasses all the records from the actual criteria stream.
            out.collect(value);
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



