public IRDAG apply()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java [97:167]


  public IRDAG apply(final IRDAG dag) {
    // Partitions the IR vertices into schedule groups, numbers the groups in topological order of the
    // group-level DAG (starting at 0), and stores that number on every vertex as a permanent
    // ScheduleGroupProperty. The same dag instance is returned.
    final Map<IRVertex, Integer> irVertexToGroupIdMap = new HashMap<>();
    final Map<Integer, List<IRVertex>> groupIdToVertices = new HashMap<>();

    // Step 1 only queries `dag` (it mutates the two local maps, never the DAG), so descendant lists are
    // memoized per vertex id. Without this, the cycle check below re-traverses the DAG once for every
    // (edge, other edge) pair of the current group, on every visited vertex.
    // NOTE(review): assumes dag.getDescendants is a side-effect-free query — confirm.
    final Map<String, List<IRVertex>> descendantsCache = new HashMap<>();

    // Step 1: Compute schedule groups.
    // Vertices are visited in topological order. A vertex not yet claimed by any group opens a new group;
    // afterwards, destinations of the group's PUSH out-edges that pass the cycle check are claimed for it.
    final MutableInt lastGroupId = new MutableInt(0);
    dag.topologicalDo(irVertex -> {
      final int curId;
      if (!irVertexToGroupIdMap.containsKey(irVertex)) {
        lastGroupId.increment();
        irVertexToGroupIdMap.put(irVertex, lastGroupId.intValue());
        curId = lastGroupId.intValue();
      } else {
        curId = irVertexToGroupIdMap.get(irVertex);
      }
      groupIdToVertices.putIfAbsent(curId, new ArrayList<>());
      groupIdToVertices.get(curId).add(irVertex);

      final List<IRVertex> verticesOfGroup = groupIdToVertices.get(curId);
      final List<IREdge> allOutEdgesOfGroup = verticesOfGroup.stream()
        .flatMap(vtx -> dag.getOutgoingEdgesOf(vtx).stream())
        .filter(edge -> !verticesOfGroup.contains(edge.getDst()))  // We don't count the group-internal edges.
        .collect(Collectors.toList());

      // Keep an out-edge only if its destination is not reachable from the destination of any *other*
      // out-edge of the group. This is intended to keep the group-level graph built in Step 2 acyclic
      // when that destination is merged into this group.
      final List<IREdge> noCycleOutEdges = allOutEdgesOfGroup.stream().filter(curEdge -> {
        final List<IREdge> outgoingEdgesWithoutCurEdge = new ArrayList<>(allOutEdgesOfGroup);
        outgoingEdgesWithoutCurEdge.remove(curEdge);
        return outgoingEdgesWithoutCurEdge.stream()
          .map(IREdge::getDst)
          .flatMap(dst -> descendantsCache.computeIfAbsent(dst.getId(), dag::getDescendants).stream())
          .noneMatch(descendant -> descendant.equals(curEdge.getDst()));
      }).collect(Collectors.toList());

      // NOTE(review): .get() assumes every edge already carries a DataFlowProperty — confirm that a
      // prerequisite pass guarantees this; otherwise this throws NoSuchElementException.
      final List<IRVertex> pushNoCycleOutEdgeDsts = noCycleOutEdges.stream()
        .filter(e -> DataFlowProperty.Value.PUSH.equals(e.getPropertyValue(DataFlowProperty.class).get()))
        .map(IREdge::getDst)
        .collect(Collectors.toList());

      // Claim the destinations for this group; each is added to groupIdToVertices when it is visited.
      pushNoCycleOutEdgeDsts.forEach(dst -> irVertexToGroupIdMap.put(dst, curId));
    });

    // Step 2: Build a DAG whose vertices are the schedule groups and whose edges mirror the IR edges
    // that cross group boundaries.
    final DAGBuilder<ScheduleGroup, ScheduleGroupEdge> builder = new DAGBuilder<>();
    final Map<Integer, ScheduleGroup> idToGroup = new HashMap<>();

    // ScheduleGroups
    groupIdToVertices.forEach((groupId, vertices) -> {
      final ScheduleGroup sg = new ScheduleGroup(groupId);
      idToGroup.put(groupId, sg);
      sg.vertices.addAll(vertices);
      builder.addVertex(sg);
    });

    // ScheduleGroupEdges: one per incoming IR edge whose source lies outside the destination's group.
    irVertexToGroupIdMap.forEach((vertex, groupId) -> {
      // Every incoming edge of `vertex` ends at `vertex`, so the destination group is `groupId` for all.
      final List<IRVertex> membersOfDstGroup = groupIdToVertices.get(groupId);
      final ScheduleGroup dstGroup = idToGroup.get(groupId);
      dag.getIncomingEdgesOf(vertex).stream()
        .filter(inEdge -> !membersOfDstGroup.contains(inEdge.getSrc()))
        .map(inEdge -> new ScheduleGroupEdge(
          idToGroup.get(irVertexToGroupIdMap.get(inEdge.getSrc())),
          dstGroup))
        .forEach(builder::connectVertices);
    });

    // Step 3: Actually set new schedule group properties based on topological ordering.
    // Groups are numbered 0, 1, 2, ... in the order sgDAG.topologicalDo visits them.
    final MutableInt actualScheduleGroup = new MutableInt(0);
    final DAG<ScheduleGroup, ScheduleGroupEdge> sgDAG = builder.buildWithoutSourceSinkCheck();
    sgDAG.topologicalDo(sg -> {
      sg.vertices.forEach(vertex ->
        vertex.setPropertyPermanently(ScheduleGroupProperty.of(actualScheduleGroup.intValue())));
      actualScheduleGroup.increment();
    });

    return dag;
  }