in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPass.java [97:167]
public IRDAG apply(final IRDAG dag) {
  // Annotates every vertex of the given DAG with a ScheduleGroupProperty and returns the same DAG instance.
  //   Step 1 partitions the vertices into groups by propagating a group id along PUSH edges.
  //   Step 2 builds a DAG whose vertices are those groups.
  //   Step 3 numbers the groups in topological order (starting from 0) and stores that number on each vertex.

  // Group id a vertex was placed in. Written exactly once per vertex, at the moment the vertex is visited,
  // so that it always agrees with groupIdToVertices.
  final Map<IRVertex, Integer> irVertexToGroupIdMap = new HashMap<>();
  // Group id that a not-yet-visited vertex will inherit from an upstream group through a PUSH edge.
  // Kept apart from irVertexToGroupIdMap on purpose: the out-edges of a group are re-evaluated on every visit of
  // one of its members, so an edge into a vertex that has already been placed in ANOTHER group can show up
  // again later. Re-labelling such a vertex would make the two maps disagree and produce wrong (possibly
  // cyclic) edges between schedule groups in Step 2.
  final Map<IRVertex, Integer> inheritedGroupIds = new HashMap<>();
  final Map<Integer, List<IRVertex>> groupIdToVertices = new HashMap<>();

  // Step 1: Compute schedule groups
  final MutableInt lastGroupId = new MutableInt(0);
  dag.topologicalDo(irVertex -> {
    // Join the group inherited from an upstream vertex, or open a fresh group.
    final Integer inheritedId = inheritedGroupIds.remove(irVertex);
    final int curId;
    if (inheritedId == null) {
      lastGroupId.increment();
      curId = lastGroupId.intValue();
    } else {
      curId = inheritedId;
    }
    irVertexToGroupIdMap.put(irVertex, curId);
    final List<IRVertex> verticesOfGroup = groupIdToVertices.computeIfAbsent(curId, id -> new ArrayList<>());
    verticesOfGroup.add(irVertex);

    // All edges leaving the group as it currently stands.
    final List<IREdge> allOutEdgesOfGroup = verticesOfGroup.stream()
      .flatMap(vtx -> dag.getOutgoingEdgesOf(vtx).stream())
      .filter(edge -> !verticesOfGroup.contains(edge.getDst())) // We don't count the group-internal edges.
      .collect(Collectors.toList());

    // Keep an out-edge only if its destination is not a descendant of the destination of any OTHER out-edge.
    // NOTE(review): presumably this prevents merging a vertex that is also reachable through another group,
    // which would make the groups depend on each other - confirm.
    final List<IREdge> noCycleOutEdges = allOutEdgesOfGroup.stream().filter(curEdge -> {
      final List<IREdge> outgoingEdgesWithoutCurEdge = new ArrayList<>(allOutEdgesOfGroup);
      outgoingEdgesWithoutCurEdge.remove(curEdge);
      return outgoingEdgesWithoutCurEdge.stream()
        .map(IREdge::getDst)
        .flatMap(dst -> dag.getDescendants(dst.getId()).stream())
        .noneMatch(descendant -> descendant.equals(curEdge.getDst()));
    }).collect(Collectors.toList());

    // Destinations of the remaining PUSH edges inherit this group, unless they have already been placed.
    // Among several upstream groups pushing into the same unvisited vertex, the last one visited wins.
    // The .get() throws if an edge carries no DataFlowProperty.
    // NOTE(review): presumably every edge has one by the time this pass runs - confirm.
    noCycleOutEdges.stream()
      .filter(e -> DataFlowProperty.Value.PUSH.equals(e.getPropertyValue(DataFlowProperty.class).get()))
      .map(IREdge::getDst)
      .filter(dst -> !irVertexToGroupIdMap.containsKey(dst))
      .forEach(dst -> inheritedGroupIds.put(dst, curId));
  });

  // Step 2: Topologically sort schedule groups
  final DAGBuilder<ScheduleGroup, ScheduleGroupEdge> builder = new DAGBuilder<>();
  final Map<Integer, ScheduleGroup> idToGroup = new HashMap<>();

  // ScheduleGroups
  groupIdToVertices.forEach((groupId, vertices) -> {
    final ScheduleGroup sg = new ScheduleGroup(groupId);
    idToGroup.put(groupId, sg);
    sg.vertices.addAll(vertices);
    builder.addVertex(sg);
  });

  // ScheduleGroupEdges: one per IR edge whose source lives in a different group than its destination.
  irVertexToGroupIdMap.forEach((vertex, groupId) -> dag.getIncomingEdgesOf(vertex).stream()
    .filter(inEdge -> !groupId.equals(irVertexToGroupIdMap.get(inEdge.getSrc())))
    .map(inEdge -> new ScheduleGroupEdge(
      idToGroup.get(irVertexToGroupIdMap.get(inEdge.getSrc())),
      idToGroup.get(irVertexToGroupIdMap.get(inEdge.getDst()))))
    .forEach(builder::connectVertices));

  // Step 3: Actually set new schedule group properties based on topological ordering
  final MutableInt actualScheduleGroup = new MutableInt(0);
  final DAG<ScheduleGroup, ScheduleGroupEdge> sgDAG = builder.buildWithoutSourceSinkCheck();
  sgDAG.topologicalDo(sg -> {
    // Every vertex of a group gets the same number; the counter advances once per group.
    sg.vertices.forEach(vertex ->
      vertex.setPropertyPermanently(ScheduleGroupProperty.of(actualScheduleGroup.intValue())));
    actualScheduleGroup.increment();
  });

  return dag;
}