private static DataStreamList createIteration()

in flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java [178:291]


    /**
     * Builds one iteration around the given streams and returns its output streams.
     *
     * <p>The construction proceeds in a fixed order, and every step registers operators, so the
     * order of the statements below must be kept:
     *
     * <ol>
     *   <li>inputs and heads are added for the initial variable streams;
     *   <li>inputs are added for the data streams, and routed through {@code addReplayer} when
     *       {@code replayedDataStreamIndices} is not empty;
     *   <li>the iteration body is invoked against draft sources created in a {@link
     *       DraftExecutionEnvironment}, which is then copied to the actual environment;
     *   <li>tails are added for the feedback streams and each tail shares a co-location group key
     *       with the head of the same index;
     *   <li>if the body returned a termination criteria, the criteria tails are added as well;
     *   <li>all tails are unioned and handed, together with the actual output streams, to {@code
     *       addOutputs}.
     * </ol>
     *
     * <p>Violated preconditions are reported through {@code checkState} (presumably an {@link
     * IllegalStateException}; confirm against the imported {@code checkState}).
     *
     * @param initVariableStreams the initial variable streams; at least one is required
     * @param dataStreams the data streams passed to the iteration body
     * @param replayedDataStreamIndices indices forwarded to {@code addReplayer}; no replayer is
     *     added when this set is empty
     * @param body the iteration body, called with the draft head streams and draft data inputs
     * @param initialOperatorWrapper the wrapper given to the {@link DraftExecutionEnvironment}
     * @param mayHaveCriteria whether the body is allowed to return a non-null termination criteria
     * @return the result of {@code addOutputs} for the actual output streams of the body
     */
    private static DataStreamList createIteration(
            DataStreamList initVariableStreams,
            DataStreamList dataStreams,
            Set<Integer> replayedDataStreamIndices,
            IterationBody body,
            OperatorWrapper<?, IterationRecord<?>> initialOperatorWrapper,
            boolean mayHaveCriteria) {
        checkState(initVariableStreams.size() > 0, "There should be at least one variable stream");

        final IterationID iterationId = new IterationID();

        final List<TypeInformation<?>> variableTypes = getTypeInfos(initVariableStreams);
        final List<TypeInformation<?>> dataTypes = getTypeInfos(dataStreams);

        // Heads and inputs. A stream without an explicit (positive) parallelism contributes the
        // parallelism configured on its execution environment instead.
        final int summedVariableParallelism =
                map(
                                initVariableStreams,
                                variableStream -> {
                                    int ownParallelism = variableStream.getParallelism();
                                    if (ownParallelism > 0) {
                                        return ownParallelism;
                                    }
                                    return variableStream
                                            .getExecutionEnvironment()
                                            .getConfig()
                                            .getParallelism();
                                })
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum();
        final DataStreamList variableInputs = addInputs(initVariableStreams);
        final DataStreamList heads =
                addHeads(
                        initVariableStreams,
                        variableInputs,
                        iterationId,
                        summedVariableParallelism,
                        false,
                        0);

        DataStreamList dataInputs = addInputs(dataStreams);
        if (!replayedDataStreamIndices.isEmpty()) {
            dataInputs =
                    addReplayer(heads.get(0), dataStreams, dataInputs, replayedDataStreamIndices);
        }

        // Build the iteration body on top of draft sources: the body's inputs are mapped to the
        // draft sources, which are the starting points of the draft subgraph.
        final StreamExecutionEnvironment env =
                initVariableStreams.get(0).getExecutionEnvironment();
        final DraftExecutionEnvironment draftEnv =
                new DraftExecutionEnvironment(env, initialOperatorWrapper);
        final DataStreamList draftHeads = addDraftSources(heads, draftEnv, variableTypes);
        final DataStreamList draftDataInputs = addDraftSources(dataInputs, draftEnv, dataTypes);

        final IterationBodyResult bodyResult = body.process(draftHeads, draftDataInputs);
        ensuresTransformationAdded(bodyResult.getFeedbackVariableStreams(), draftEnv);
        ensuresTransformationAdded(bodyResult.getOutputStreams(), draftEnv);
        draftEnv.copyToActualEnvironment();

        // Validate the feedback streams against the heads before adding the tails.
        final DataStreamList feedbacks =
                getActualDataStreams(bodyResult.getFeedbackVariableStreams(), draftEnv);
        final int feedbackCount = feedbacks.size();
        final int variableCount = initVariableStreams.size();
        checkState(
                feedbackCount == variableCount,
                "The number of feedback streams "
                        + feedbackCount
                        + " does not match the initialized one "
                        + variableCount);
        for (int index = 0; index < feedbacks.size(); ++index) {
            final int feedbackParallelism = feedbacks.get(index).getParallelism();
            final int headParallelism = heads.get(index).getParallelism();
            checkState(
                    feedbackParallelism == headParallelism,
                    String.format(
                            "The feedback stream %d have different parallelism %d with the initial stream, which is %d",
                            index, feedbackParallelism, headParallelism));
        }

        // Add the tails and put every tail into the same co-location group as its head.
        final DataStreamList tails = addTails(feedbacks, iterationId, 0);
        final String groupKeyPrefix = "co-" + iterationId.toHexString() + "-";
        for (int index = 0; index < heads.size(); ++index) {
            final String groupKey = groupKeyPrefix + index;
            heads.get(index).getTransformation().setCoLocationGroupKey(groupKey);
            tails.get(index).getTransformation().setCoLocationGroupKey(groupKey);
        }

        final List<DataStream<?>> allTails = new ArrayList<>(tails.getDataStreams());

        final boolean hasCriteria = bodyResult.getTerminationCriteria() != null;
        checkState(
                mayHaveCriteria || !hasCriteria,
                "The current iteration type does not support the termination criteria.");
        if (hasCriteria) {
            final DataStreamList criteriaTails =
                    addCriteriaStream(
                            bodyResult.getTerminationCriteria(),
                            iterationId,
                            env,
                            draftEnv,
                            initVariableStreams,
                            heads,
                            summedVariableParallelism);
            allTails.addAll(criteriaTails.getDataStreams());
        }

        final DataStream<Integer> unionedTails = unionAllTails(env, new DataStreamList(allTails));

        final DataStreamList actualOutputs =
                getActualDataStreams(bodyResult.getOutputStreams(), draftEnv);
        return addOutputs(actualOutputs, unionedTails);
    }