private DAG loopRolling()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPass.java [190:287]


  /**
   * Rolls consecutive LoopVertices that belong to the same loop back into a single root LoopVertex.
   *
   * <p>Vertices are visited in topological order:
   * <ul>
   *   <li>Source vertices are added to the new DAG unchanged.</li>
   *   <li>Operator vertices are added through {@code addVertexToBuilder}, which also receives the
   *       (LoopVertex --> root LoopVertex) mapping built so far.</li>
   *   <li>A LoopVertex whose name contains the current root's name is folded into that root. The root's
   *       iteration count is increased and its edges are rewritten, and the LoopVertex itself is never added
   *       to the new DAG.</li>
   *   <li>Any other LoopVertex becomes the new root and is added to the new DAG.</li>
   * </ul>
   *
   * @param dag DAG whose loops have been extracted into LoopVertices.
   * @return the newly built DAG, in which each group of rolled loops is represented by its root LoopVertex.
   * @throws UnsupportedOperationException if a vertex is neither a SourceVertex, an OperatorVertex nor a
   *                                       LoopVertex.
   */
  private DAG<IRVertex, IREdge> loopRolling(final DAG<IRVertex, IREdge> dag) {
    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
    // Every LoopVertex encountered so far --> the root LoopVertex representing it (a root maps to itself).
    final HashMap<LoopVertex, LoopVertex> rootOfLoop = new HashMap<>();
    // Root LoopVertex --> (vertex of the root or of a rolled iteration --> matching vertex in the root's DAG).
    final HashMap<LoopVertex, HashMap<IRVertex, IRVertex>> vertexMappingOfRoot = new HashMap<>();
    // Root that subsequent LoopVertices are matched against by name; null until the first LoopVertex is seen.
    LoopVertex currentRoot = null;

    for (final IRVertex vertex : dag.getTopologicalSort()) {
      if (vertex instanceof SourceVertex) {
        builder.addVertex(vertex, dag);
        continue;
      }
      if (vertex instanceof OperatorVertex) {
        addVertexToBuilder(builder, dag, vertex, rootOfLoop);
        continue;
      }
      if (!(vertex instanceof LoopVertex)) {
        throw new UnsupportedOperationException("Unknown vertex type: " + vertex);
      }

      final LoopVertex loop = (LoopVertex) vertex;
      final boolean isFollowingIteration =
          currentRoot != null && loop.getName().contains(currentRoot.getName());

      if (isFollowingIteration) {
        // Fold this iteration into the current root; it is not added to the builder itself.
        rootOfLoop.putIfAbsent(loop, currentRoot);
        currentRoot.increaseMaxNumberOfIterations();
        rollLoopIntoRoot(currentRoot, loop, vertexMappingOfRoot.get(currentRoot));
      } else {
        // Start a new root loop: each of its own vertices is its own equivalent.
        currentRoot = loop;
        rootOfLoop.putIfAbsent(currentRoot, currentRoot);
        vertexMappingOfRoot.putIfAbsent(currentRoot, new HashMap<>());
        final HashMap<IRVertex, IRVertex> rootMapping = vertexMappingOfRoot.get(currentRoot);
        for (final IRVertex member : currentRoot.getDAG().getTopologicalSort()) {
          rootMapping.putIfAbsent(member, member);
          IdManager.saveVertexId(member, member.getId());
        }
        addVertexToBuilder(builder, dag, currentRoot, rootOfLoop);
      }
    }

    return builder.build();
  }

  /**
   * Folds one following iteration of a loop into its root LoopVertex.
   *
   * <p>The vertices of both loop bodies are paired in topological order. This relies on getTopologicalSort()
   * giving consistent results. Each pair is recorded in {@code equivalentVertices} and with {@link IdManager}.
   * The root's incoming edges and DAG outgoing edges are then replaced by copies of {@code loop}'s edges,
   * re-targeted at the root's vertices.
   *
   * @param root               the root LoopVertex that represents all iterations.
   * @param loop               the iteration being rolled into {@code root}.
   * @param equivalentVertices mapping (vertex of any known iteration --> vertex of the root), updated in place.
   */
  private void rollLoopIntoRoot(final LoopVertex root,
                                final LoopVertex loop,
                                final HashMap<IRVertex, IRVertex> equivalentVertices) {
    final Iterator<IRVertex> rootMembers = root.getDAG().getTopologicalSort().iterator();
    final Iterator<IRVertex> loopMembers = loop.getDAG().getTopologicalSort().iterator();
    // Zip both bodies; if their sizes differ, the surplus vertices of the longer one stay unmapped.
    while (rootMembers.hasNext() && loopMembers.hasNext()) {
      final IRVertex loopMember = loopMembers.next();
      final IRVertex rootMember = rootMembers.next();
      equivalentVertices.put(loopMember, rootMember);
      IdManager.saveVertexId(rootMember, loopMember.getId());
    }

    // Both incoming-edge lists are rebuilt from scratch using this iteration's incoming edges.
    root.getNonIterativeIncomingEdges().clear();
    root.getIterativeIncomingEdges().clear();
    loop.getDagIncomingEdges().forEach((dst, edgesIntoDst) -> edgesIntoDst.forEach(edge -> {
      final IRVertex originalSrc = edge.getSrc();
      final IRVertex rootDst = equivalentVertices.get(dst);
      // A source that is already mapped lives in a previous iteration, so the edge feeds back into the loop.
      final boolean fromPreviousIteration = equivalentVertices.containsKey(originalSrc);
      final IRVertex rootSrc = fromPreviousIteration ? equivalentVertices.get(originalSrc) : originalSrc;

      final IREdge rolledEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
          rootSrc, rootDst);
      edge.copyExecutionPropertiesTo(rolledEdge);
      if (fromPreviousIteration) {
        root.addIterativeIncomingEdge(rolledEdge);
      } else {
        root.addNonIterativeIncomingEdge(rolledEdge);
      }
    }));

    // Outgoing edges of the root become those of the latest rolled iteration.
    root.getDagOutgoingEdges().clear();
    loop.getDagOutgoingEdges().forEach((src, edgesFromSrc) -> edgesFromSrc.forEach(edge -> {
      final IRVertex originalDst = edge.getDst();
      final IRVertex rootSrc = equivalentVertices.get(src);

      final IREdge rolledEdge = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
          rootSrc, originalDst);
      edge.copyExecutionPropertiesTo(rolledEdge);
      root.addDagOutgoingEdge(rolledEdge);
      root.mapEdgeWithLoop(loop.getEdgeWithLoop(edge), rolledEdge);
    }));
  }