in flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/Iterations.java [178:291]
    private static DataStreamList createIteration(
            DataStreamList initVariableStreams,
            DataStreamList dataStreams,
            Set<Integer> replayedDataStreamIndices,
            IterationBody body,
            OperatorWrapper<?, IterationRecord<?>> initialOperatorWrapper,
            boolean mayHaveCriteria) {
        checkState(initVariableStreams.size() > 0, "There should be at least one variable stream");
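
        // Each iteration instance gets a unique id; it is passed to the heads, tails and the
        // criteria stream created below and is embedded in their co-location group keys.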
        IterationID iterationId = new IterationID();

        List<TypeInformation<?>> initVariableTypeInfos = getTypeInfos(initVariableStreams);
        List<TypeInformation<?>> dataStreamTypeInfos = getTypeInfos(dataStreams);

        // Add heads and inputs
        int totalInitVariableParallelism =
                map(
                                initVariableStreams,
                                dataStream ->
                                        dataStream.getParallelism() > 0
                                                ? dataStream.getParallelism()
                                                : dataStream
                                                        .getExecutionEnvironment()
                                                        .getConfig()
                                                        .getParallelism())
                        .stream()
                        .mapToInt(i -> i)
                        .sum();
        DataStreamList initVariableInputs = addInputs(initVariableStreams);
        DataStreamList headStreams =
                addHeads(
                        initVariableStreams,
                        initVariableInputs,
                        iterationId,
                        totalInitVariableParallelism,
                        false,
                        0);
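
        // Also wraps the constant data streams as iteration inputs. The streams whose indices are
        // listed in replayedDataStreamIndices additionally go through a replayer, driven by the
        // first head stream, so that they can be replayed in later rounds.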
        DataStreamList dataStreamInputs = addInputs(dataStreams);
        if (replayedDataStreamIndices.size() > 0) {
            dataStreamInputs =
                    addReplayer(
                            headStreams.get(0),
                            dataStreams,
                            dataStreamInputs,
                            replayedDataStreamIndices);
        }

        // Creates the iteration body. We map the inputs of the iteration body onto the draft
        // sources, which serve as the starting points for building the draft subgraph.
        StreamExecutionEnvironment env = initVariableStreams.get(0).getExecutionEnvironment();
        DraftExecutionEnvironment draftEnv =
                new DraftExecutionEnvironment(env, initialOperatorWrapper);
        DataStreamList draftHeadStreams =
                addDraftSources(headStreams, draftEnv, initVariableTypeInfos);
        DataStreamList draftDataStreamInputs =
                addDraftSources(dataStreamInputs, draftEnv, dataStreamTypeInfos);
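
        // Lets the user-defined body build its subgraph on top of the draft sources. The
        // transformations recorded in the draft environment are then copied into the actual
        // environment, wrapped with the operator wrapper given to this method.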
        IterationBodyResult iterationBodyResult =
                body.process(draftHeadStreams, draftDataStreamInputs);
        ensuresTransformationAdded(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        ensuresTransformationAdded(iterationBodyResult.getOutputStreams(), draftEnv);
        draftEnv.copyToActualEnvironment();

        // Adds the tails and co-locates them with the heads.
        DataStreamList feedbackStreams =
                getActualDataStreams(iterationBodyResult.getFeedbackVariableStreams(), draftEnv);
        checkState(
                feedbackStreams.size() == initVariableStreams.size(),
                "The number of feedback streams "
                        + feedbackStreams.size()
                        + " does not match the number of init variable streams "
                        + initVariableStreams.size());
        for (int i = 0; i < feedbackStreams.size(); ++i) {
            checkState(
                    feedbackStreams.get(i).getParallelism() == headStreams.get(i).getParallelism(),
                    String.format(
                            "The feedback stream %d has parallelism %d, which does not match the parallelism %d of the corresponding head stream",
                            i,
                            feedbackStreams.get(i).getParallelism(),
                            headStreams.get(i).getParallelism()));
        }

        DataStreamList tails = addTails(feedbackStreams, iterationId, 0);
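        // Co-locating each head with its corresponding tail ensures the two run in the same slot,
        // so that feedback records can be handed over locally.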
        for (int i = 0; i < headStreams.size(); ++i) {
            String coLocationGroupKey = "co-" + iterationId.toHexString() + "-" + i;
            headStreams.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
            tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
        }

        List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams());

        checkState(
                mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null,
                "The current iteration type does not support a termination criteria.");
        if (iterationBodyResult.getTerminationCriteria() != null) {
            DataStreamList criteriaTails =
                    addCriteriaStream(
                            iterationBodyResult.getTerminationCriteria(),
                            iterationId,
                            env,
                            draftEnv,
                            initVariableStreams,
                            headStreams,
                            totalInitVariableParallelism);
            tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams());
        }
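
        // Unions all the tails (including the criteria tails, if any) into a single stream that is
        // passed to addOutputs together with the user-visible output streams.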
        DataStream<Integer> tailsUnion =
                unionAllTails(env, new DataStreamList(tailsAndCriteriaTails));

        return addOutputs(
                getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion);
    }
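
    // A minimal usage sketch (for orientation only, not part of the original file), assuming the
    // public entry point Iterations.iterateUnboundedStreams and a body that simply feeds its
    // single variable stream back and exposes it as the output; initVariable and dataset are
    // placeholder DataStream<Integer> inputs:
    //
    //     DataStreamList outputs =
    //             Iterations.iterateUnboundedStreams(
    //                     DataStreamList.of(initVariable),
    //                     DataStreamList.of(dataset),
    //                     (variableStreams, dataStreams) -> {
    //                         // A real body would transform the variable here before feeding
    //                         // it back.
    //                         DataStream<Integer> feedback = variableStreams.get(0);
    //                         return new IterationBodyResult(
    //                                 DataStreamList.of(feedback), DataStreamList.of(feedback));
    //                     });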