public DAG stagePartitionIrDAG()

in runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java [133:241]


  /**
   * Builds a DAG of {@link Stage}s, connected by {@link StageEdge}s, out of the given IR DAG.
   *
   * The vertices are grouped by the stage id that {@link StagePartitioner} assigns to them. For every group,
   * a stage-internal DAG is assembled from the vertices and the edges whose both ends lie in the group;
   * edges crossing two groups are turned into stage edges once all stages have been created.
   *
   * @param irDAG the IR DAG to partition into stages.
   * @return the DAG whose vertices are stages and whose edges are stage edges.
   * @throws PhysicalPlanGenerationException if obtaining the readables of a source vertex fails.
   * @throws IllegalVertexOperationException if an inter-stage edge refers to a stage that was not created.
   * @throws RuntimeException if the parallelism property is missing from the stage properties.
   */
  public DAG<Stage, StageEdge> stagePartitionIrDAG(final IRDAG irDAG) {
    final StagePartitioner partitioner = new StagePartitioner();
    final DAGBuilder<Stage, StageEdge> stageDagBuilder = new DAGBuilder<>();
    final Set<IREdge> crossStageEdges = new HashSet<>();
    final Map<Integer, Stage> stagesById = new HashMap<>();
    final Map<IRVertex, Integer> stageIdOf = partitioner.apply(irDAG);
    final HashSet<IRVertex> alreadyPartitioned = new HashSet<>();
    // Seeded with this object's hashCode(); intended to produce same results for same input IRDAGs.
    final Random rng = new Random(hashCode());

    // Group the vertices by stage id. Stage ids are kept in the order in which
    // the topological traversal first encounters them.
    final Map<Integer, Set<IRVertex>> verticesByStage = new LinkedHashMap<>();
    irDAG.topologicalDo(vertex -> {
      final int owningStageId = stageIdOf.get(vertex);
      verticesByStage.computeIfAbsent(owningStageId, unused -> new HashSet<>()).add(vertex);
    });

    for (final Map.Entry<Integer, Set<IRVertex>> stageEntry : verticesByStage.entrySet()) {
      final int stageId = stageEntry.getKey();
      final Set<IRVertex> members = stageEntry.getValue();
      final String stageName = RuntimeIdManager.generateStageId(stageId);

      // The stage properties are taken from one (arbitrary) member of the stage.
      final ExecutionPropertyMap<VertexExecutionProperty> stageProps = new ExecutionPropertyMap<>(stageName);
      partitioner.getStageProperties(members.iterator().next()).forEach(stageProps::put);
      final int parallelism = stageProps.get(ParallelismProperty.class)
        .orElseThrow(() -> new RuntimeException("Parallelism property must be set for Stage"));
      final List<Integer> taskIndices = getTaskIndicesToExecute(members, parallelism, rng);
      final DAGBuilder<IRVertex, RuntimeEdge<IRVertex>> innerDagBuilder = new DAGBuilder<>();

      // One (vertex id -> readable) map per task index of this stage.
      final List<Map<String, Readable>> readablesPerTask = new ArrayList<>(parallelism);
      while (readablesPerTask.size() < parallelism) {
        readablesPerTask.add(new HashMap<>());
      }

      // Register every member vertex, collecting the readables of source vertices on the way.
      for (final IRVertex member : members) {
        final IRVertex actual = getActualVertexToPutIntoStage(member);
        final boolean needsReadables = actual instanceof SourceVertex && !alreadyPartitioned.contains(actual);

        if (needsReadables) {
          final SourceVertex source = (SourceVertex) actual;
          try {
            final List<Readable> readables = source.getReadables(parallelism);
            for (int taskIdx = 0; taskIdx < parallelism; taskIdx++) {
              readablesPerTask.get(taskIdx).put(actual.getId(), readables.get(taskIdx));
            }
          } catch (final Exception e) {
            throw new PhysicalPlanGenerationException(e);
          }
          // The readables have been captured, so the vertex's internal states are cleared.
          source.clearInternalStates();
        }

        innerDagBuilder.addVertex(actual);
      }

      // Wire up the incoming edges of every member vertex.
      for (final IRVertex dstVertex : members) {
        irDAG.getIncomingEdgesOf(dstVertex).forEach(irEdge -> {
          final IRVertex srcVertex = irEdge.getSrc();
          final Integer srcStageId = stageIdOf.get(srcVertex);
          final Integer dstStageId = stageIdOf.get(dstVertex);

          if (!srcStageId.equals(dstStageId)) {
            // The edge originates in another stage: defer it until all stages exist.
            crossStageEdges.add(irEdge);
            return;
          }

          // Both ends belong to this stage: the edge goes into the stage-internal DAG.
          innerDagBuilder.connectVertices(new RuntimeEdge<>(
            irEdge.getId(),
            irEdge.getExecutionProperties(),
            getActualVertexToPutIntoStage(irEdge.getSrc()),
            getActualVertexToPutIntoStage(irEdge.getDst())));
        });
      }

      // Only a non-empty stage-internal DAG results in a stage.
      if (!innerDagBuilder.isEmpty()) {
        final DAG<IRVertex, RuntimeEdge<IRVertex>> innerDag = innerDagBuilder.buildWithoutSourceSinkCheck();
        final Stage stage = new Stage(stageName, taskIndices, innerDag, stageProps, readablesPerTask);
        stageDagBuilder.addVertex(stage);
        stagesById.put(stageId, stage);
      }

      // Remember the vertices of this stage, so that the readables check above
      // skips them in the following iterations.
      alreadyPartitioned.addAll(members);
    }

    // Every stage exists by now, so the deferred inter-stage edges can be converted into stage edges.
    for (final IREdge edge : crossStageEdges) {
      final Stage fromStage = stagesById.get(stageIdOf.get(edge.getSrc()));
      final Stage toStage = stagesById.get(stageIdOf.get(edge.getDst()));
      if (fromStage == null || toStage == null) {
        final String missingSrc = fromStage == null ? String.format(" source stage for %s", edge.getSrc()) : "";
        final String missingDst = toStage == null ? String.format(" destination stage for %s", edge.getDst()) : "";
        throw new IllegalVertexOperationException(
          String.format("Stage not added to the builder:%s%s", missingSrc, missingDst));
      }
      stageDagBuilder.connectVertices(new StageEdge(edge.getId(), edge.getExecutionProperties(),
        getActualVertexToPutIntoStage(edge.getSrc()), getActualVertexToPutIntoStage(edge.getDst()),
        fromStage, toStage));
    }

    return stageDagBuilder.build();
  }