in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java [190:287]
/**
 * Rolls repeated loop vertices of the given DAG into a single root loop vertex.
 *
 * The DAG is visited in topological order. Source vertices are added to the builder directly and operator
 * vertices are added through {@code addVertexToBuilder}. For loop vertices:
 * <ul>
 *   <li>A loop vertex becomes the new root when no root exists yet, or when its name does not contain the
 *   name of the current root. Its inner vertices are registered as equivalent to themselves, and it is
 *   added to the builder.</li>
 *   <li>Any other loop vertex is folded into the current root: the root's iteration count is increased,
 *   the inner vertices of both loops are paired position by position along their topological sorts,
 *   and the root's incoming and outgoing DAG edges are replaced by the edges of this loop, re-pointed at
 *   the root's equivalent vertices. Such a loop vertex is not itself added to the builder here.</li>
 * </ul>
 *
 * @param dag DAG to roll the loops of.
 * @return the DAG produced by the builder after visiting every vertex.
 * @throws UnsupportedOperationException if a vertex is not a source, operator, or loop vertex.
 */
private DAG<IRVertex, IREdge> loopRolling(final DAG<IRVertex, IREdge> dag) {
  final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
  // LoopVertex --> the RootLoopVertex it belongs to (a root maps to itself).
  final HashMap<LoopVertex, LoopVertex> loopVerticesOfSameLoop = new HashMap<>();
  // RootLoopVertex --> (vertex of a rolled loop --> matching vertex of the root loop).
  final HashMap<LoopVertex, HashMap<IRVertex, IRVertex>> equivalentVerticesOfLoops = new HashMap<>();
  // Root loop that following loops are folded into; null until the first LoopVertex is visited.
  LoopVertex rootLoopVertex = null;

  for (final IRVertex current : dag.getTopologicalSort()) {
    if (current instanceof SourceVertex) {
      builder.addVertex(current, dag);
      continue;
    }
    if (current instanceof OperatorVertex) {
      addVertexToBuilder(builder, dag, current, loopVerticesOfSameLoop);
      continue;
    }
    if (!(current instanceof LoopVertex)) {
      throw new UnsupportedOperationException("Unknown vertex type: " + current);
    }

    final LoopVertex loop = (LoopVertex) current;
    final boolean continuesCurrentRoot =
      rootLoopVertex != null && loop.getName().contains(rootLoopVertex.getName());

    if (!continuesCurrentRoot) {
      // This loop starts a new group: it becomes the root and is the one kept in the result.
      rootLoopVertex = loop;
      loopVerticesOfSameLoop.putIfAbsent(loop, loop);
      final HashMap<IRVertex, IRVertex> rootMapping =
        equivalentVerticesOfLoops.computeIfAbsent(loop, key -> new HashMap<>());
      // Every vertex of the root loop is its own equivalent.
      for (final IRVertex inner : loop.getDAG().getTopologicalSort()) {
        rootMapping.putIfAbsent(inner, inner);
        IdManager.saveVertexId(inner, inner.getId());
      }
      addVertexToBuilder(builder, dag, loop, loopVerticesOfSameLoop);
      continue;
    }

    // A following iteration of the current root: fold it into the root.
    final LoopVertex root = rootLoopVertex; // final alias so that the lambdas below can capture it.
    loopVerticesOfSameLoop.putIfAbsent(loop, root);
    root.increaseMaxNumberOfIterations();

    // Pair the vertices of both loops by position; this relies on getTopologicalSort() being consistent.
    final Iterator<IRVertex> rootIterator = root.getDAG().getTopologicalSort().iterator();
    final Iterator<IRVertex> rolledIterator = loop.getDAG().getTopologicalSort().iterator();
    final HashMap<IRVertex, IRVertex> rolledToRoot = equivalentVerticesOfLoops.get(root);
    while (rootIterator.hasNext() && rolledIterator.hasNext()) {
      final IRVertex rolledVertex = rolledIterator.next();
      final IRVertex rootVertex = rootIterator.next();
      rolledToRoot.put(rolledVertex, rootVertex);
      IdManager.saveVertexId(rootVertex, rolledVertex.getId());
    }

    // Rebuild the root's incoming edges from the incoming edges of this iteration.
    root.getNonIterativeIncomingEdges().clear();
    root.getIterativeIncomingEdges().clear();
    loop.getDagIncomingEdges().forEach((dst, incomingEdges) -> incomingEdges.forEach(incoming -> {
      final IRVertex src = incoming.getSrc();
      final IRVertex rootDst = rolledToRoot.get(dst);
      // A source already known to the mapping lives in a previous iteration, so the edge is iterative;
      // otherwise the source is outside the loop and is kept as it is.
      final boolean fromPreviousIteration = rolledToRoot.containsKey(src);
      final IRVertex rewiredSrc = fromPreviousIteration ? rolledToRoot.get(src) : src;
      final IREdge rewired = new IREdge(
        incoming.getPropertyValue(CommunicationPatternProperty.class).get(), rewiredSrc, rootDst);
      incoming.copyExecutionPropertiesTo(rewired);
      if (fromPreviousIteration) {
        root.addIterativeIncomingEdge(rewired);
      } else {
        root.addNonIterativeIncomingEdge(rewired);
      }
    }));

    // Replace the root's outgoing edges with the ones of this iteration, starting from the root's vertices.
    root.getDagOutgoingEdges().clear();
    loop.getDagOutgoingEdges().forEach((src, outgoingEdges) -> outgoingEdges.forEach(outgoing -> {
      final IRVertex dst = outgoing.getDst();
      final IRVertex rootSrc = rolledToRoot.get(src);
      final IREdge rewired = new IREdge(
        outgoing.getPropertyValue(CommunicationPatternProperty.class).get(), rootSrc, dst);
      outgoing.copyExecutionPropertiesTo(rewired);
      root.addDagOutgoingEdge(rewired);
      root.mapEdgeWithLoop(loop.getEdgeWithLoop(outgoing), rewired);
    }));
  }

  return builder.build();
}