in runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java [133:241]
/**
 * Builds a DAG of {@link Stage}s connected by {@link StageEdge}s from the given IR DAG.
 *
 * <p>The vertex-to-stage-id assignment comes from {@link StagePartitioner#apply}. Vertices are grouped
 * per stage id in the order in which {@code irDAG.topologicalDo} first visits each stage. For every stage:
 * <ul>
 *   <li>the stage execution properties are copied from the properties that the partitioner reports for
 *       one vertex of the stage (whichever vertex the set iterator yields first);</li>
 *   <li>for each source vertex, {@code getReadables(stageParallelism)} is called and the i-th readable is
 *       registered under the vertex id for task index i;</li>
 *   <li>edges whose both ends are in the stage become {@link RuntimeEdge}s of the stage-internal DAG,
 *       while edges coming from another stage are collected and later turned into {@link StageEdge}s.</li>
 * </ul>
 *
 * @param irDAG the IR DAG to partition into stages.
 * @return the DAG of stages and stage edges.
 * @throws RuntimeException if a stage has no {@code ParallelismProperty}.
 * @throws PhysicalPlanGenerationException if fetching or distributing the readables of a source vertex fails.
 * @throws IllegalVertexOperationException if an inter-stage edge refers to a stage that was not built.
 */
public DAG<Stage, StageEdge> stagePartitionIrDAG(final IRDAG irDAG) {
  final StagePartitioner stagePartitioner = new StagePartitioner();
  final DAGBuilder<Stage, StageEdge> dagOfStagesBuilder = new DAGBuilder<>();
  final Set<IREdge> interStageEdges = new HashSet<>();
  final Map<Integer, Stage> stageIdToStageMap = new HashMap<>();
  final Map<IRVertex, Integer> vertexToStageIdMap = stagePartitioner.apply(irDAG);
  // NOTE(review): this set is local to a single invocation, so the guard below can only take effect
  // within one call of this method - confirm whether it was meant to live across plan re-generations.
  final Set<IRVertex> isStagePartitioned = new HashSet<>();
  // NOTE(review): the seed is this generator's hashCode(); whether that is stable across runs depends on
  // how hashCode() is defined for this class - confirm.
  final Random random = new Random(hashCode()); // to produce same results for same input IRDAGs

  // Group the vertices by stage id. A LinkedHashMap keeps the stages in first-visit order.
  final Map<Integer, Set<IRVertex>> vertexSetForEachStage = new LinkedHashMap<>();
  irDAG.topologicalDo(irVertex -> {
    // Unboxing on purpose: a vertex without a stage id fails fast here instead of creating a null key.
    final int stageId = vertexToStageIdMap.get(irVertex);
    vertexSetForEachStage.computeIfAbsent(stageId, key -> new HashSet<>()).add(irVertex);
  });

  for (final Map.Entry<Integer, Set<IRVertex>> stageEntry : vertexSetForEachStage.entrySet()) {
    final int stageId = stageEntry.getKey();
    final Set<IRVertex> stageVertices = stageEntry.getValue();
    final String stageIdentifier = RuntimeIdManager.generateStageId(stageId);

    // The stage properties are taken from a single representative vertex of the stage.
    final ExecutionPropertyMap<VertexExecutionProperty> stageProperties =
      new ExecutionPropertyMap<>(stageIdentifier);
    stagePartitioner.getStageProperties(stageVertices.iterator().next()).forEach(stageProperties::put);
    final int stageParallelism = stageProperties.get(ParallelismProperty.class)
      .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
    final List<Integer> taskIndices = getTaskIndicesToExecute(stageVertices, stageParallelism, random);
    final DAGBuilder<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAGBuilder = new DAGBuilder<>();

    // Prepare vertexIdToReadables: one (vertex id -> readable) map per task index.
    final List<Map<String, Readable>> vertexIdToReadables = new ArrayList<>(stageParallelism);
    for (int i = 0; i < stageParallelism; i++) {
      vertexIdToReadables.add(new HashMap<>());
    }

    // For each IRVertex,
    for (final IRVertex v : stageVertices) {
      final IRVertex vertexToPutIntoStage = getActualVertexToPutIntoStage(v);

      // Take care of the readables of a source vertex.
      if (vertexToPutIntoStage instanceof SourceVertex && !isStagePartitioned.contains(vertexToPutIntoStage)) {
        final SourceVertex sourceVertex = (SourceVertex) vertexToPutIntoStage;
        try {
          final List<Readable> readables = sourceVertex.getReadables(stageParallelism);
          // The i-th readable goes to the task with index i.
          for (int i = 0; i < stageParallelism; i++) {
            vertexIdToReadables.get(i).put(vertexToPutIntoStage.getId(), readables.get(i));
          }
        } catch (final Exception e) {
          throw new PhysicalPlanGenerationException(e);
        }
        // Clear internal metadata.
        sourceVertex.clearInternalStates();
      }

      // Add vertex to the stage.
      stageInternalDAGBuilder.addVertex(vertexToPutIntoStage);
    }

    for (final IRVertex dstVertex : stageVertices) {
      // Connect all the incoming edges for the vertex.
      irDAG.getIncomingEdgesOf(dstVertex).forEach(irEdge -> {
        final IRVertex srcVertex = irEdge.getSrc();

        // both vertices are in the same stage.
        if (vertexToStageIdMap.get(srcVertex).equals(vertexToStageIdMap.get(dstVertex))) {
          stageInternalDAGBuilder.connectVertices(new RuntimeEdge<>(
            irEdge.getId(),
            irEdge.getExecutionProperties(),
            getActualVertexToPutIntoStage(irEdge.getSrc()),
            getActualVertexToPutIntoStage(irEdge.getDst())));
        } else { // edge comes from another stage
          interStageEdges.add(irEdge);
        }
      });
    }

    // If this runtime stage contains at least one vertex, build it!
    if (!stageInternalDAGBuilder.isEmpty()) {
      final DAG<IRVertex, RuntimeEdge<IRVertex>> stageInternalDAG
        = stageInternalDAGBuilder.buildWithoutSourceSinkCheck();
      final Stage stage = new Stage(
        stageIdentifier,
        taskIndices,
        stageInternalDAG,
        stageProperties,
        vertexIdToReadables);
      dagOfStagesBuilder.addVertex(stage);
      stageIdToStageMap.put(stageId, stage);
    }

    // To prevent re-fetching readables in source vertex
    // during re-generation of physical plan for dynamic optimization.
    isStagePartitioned.addAll(stageVertices);
  }

  // Add StageEdges
  for (final IREdge interStageEdge : interStageEdges) {
    final Stage srcStage = stageIdToStageMap.get(vertexToStageIdMap.get(interStageEdge.getSrc()));
    final Stage dstStage = stageIdToStageMap.get(vertexToStageIdMap.get(interStageEdge.getDst()));
    // Both end stages must have been built above; otherwise report which side is missing.
    if (srcStage == null || dstStage == null) {
      throw new IllegalVertexOperationException(String.format("Stage not added to the builder:%s%s",
        srcStage == null ? String.format(" source stage for %s", interStageEdge.getSrc()) : "",
        dstStage == null ? String.format(" destination stage for %s", interStageEdge.getDst()) : ""));
    }
    dagOfStagesBuilder.connectVertices(new StageEdge(interStageEdge.getId(), interStageEdge.getExecutionProperties(),
      getActualVertexToPutIntoStage(interStageEdge.getSrc()), getActualVertexToPutIntoStage(interStageEdge.getDst()),
      srcStage, dstStage));
  }

  return dagOfStagesBuilder.build();
}