public IRDAG apply()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java [116:206]


    public IRDAG apply(final IRDAG inputDAG) {
      // Fuses loop vertices that have equal termination conditions and no path between them into single loop
      // vertices (via mergeLoopVertices), then rebuilds the DAG with every edge that touched an original loop
      // vertex redirected to the vertex that replaced it.
      inputDAG.reshapeUnsafely(dag -> {
        final List<LoopVertex> loopVertices = new ArrayList<>();
        final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
        final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();

        collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);

        // Greedily partition the loops into fusion groups: a loop joins the first group in which every member
        // has an equal termination condition and no path (in either direction) to it, otherwise it starts a
        // new group. This keeps every loop in exactly one group and never fuses two dependent loops.
        // Groups are kept in a List because they are mutated after creation (mutable elements must not live
        // inside a HashSet, whose lookups rely on stable hash codes).
        // NOTE(review): pairwise independence inside a group does not by itself rule out a cycle between two
        // fused groups (e.g. A->C and D->B with {A, B} and {C, D} fused) - confirm this cannot arise here.
        final List<Set<LoopVertex>> fusionGroups = new ArrayList<>();
        for (final LoopVertex loopVertex : loopVertices) {
          final Optional<Set<LoopVertex>> compatibleGroup = fusionGroups.stream()
            .filter(group -> group.stream().allMatch(member ->
              loopVertex.terminationConditionEquals(member)
                && !dag.pathExistsBetween(member, loopVertex)
                && !dag.pathExistsBetween(loopVertex, member)))
            .findFirst();
          if (compatibleGroup.isPresent()) {
            compatibleGroup.get().add(loopVertex);
          } else {
            final Set<LoopVertex> newGroup = new HashSet<>();
            newGroup.add(loopVertex);
            fusionGroups.add(newGroup);
          }
        }

        // Add exactly one vertex per group (the fused vertex, or the lone loop itself) BEFORE connecting any
        // edge, and remember what replaced each original loop vertex. Connecting only after all vertices exist
        // makes the result independent of group processing order.
        final Map<LoopVertex, IRVertex> replacements = new HashMap<>();
        for (final Set<LoopVertex> group : fusionGroups) {
          if (group.size() > 1) {
            final LoopVertex fusedLoopVertex = mergeLoopVertices(group);
            builder.addVertex(fusedLoopVertex, dag);
            group.forEach(member -> replacements.put(member, fusedLoopVertex));
          } else {
            final LoopVertex loneLoopVertex = group.iterator().next();
            builder.addVertex(loneLoopVertex);
            replacements.put(loneLoopVertex, loneLoopVertex);
          }
        }

        // Reconnect each edge incident to a loop vertex exactly once. The same edge object may be recorded as an
        // in-edge of its destination and as an out-edge of its source, so de-duplicate by identity (IREdge
        // equality is not relied upon, to keep distinct parallel edges).
        final Set<IREdge> handledEdges =
          java.util.Collections.newSetFromMap(new java.util.IdentityHashMap<IREdge, Boolean>());
        for (final LoopVertex loopVertex : loopVertices) {
          for (final IREdge edge : inEdges.getOrDefault(loopVertex, new ArrayList<>())) {
            if (handledEdges.add(edge)) {
              connectRedirected(builder, edge, replacements);
            }
          }
          for (final IREdge edge : outEdges.getOrDefault(loopVertex, new ArrayList<>())) {
            if (handledEdges.add(edge)) {
              connectRedirected(builder, edge, replacements);
            }
          }
        }

        return builder.build();
      });

      return inputDAG;
    }

    /**
     * Connects {@code edge} in {@code builder}, with each endpoint replaced by its entry in {@code replacements}
     * (endpoints without an entry are kept). The edge is skipped when either resulting endpoint is not in the
     * builder. If no endpoint changed, the original edge object is connected; otherwise a new edge with the same
     * communication pattern and copied execution properties is created.
     *
     * @param builder      the builder to connect the edge in.
     * @param edge         the original edge.
     * @param replacements original loop vertex to the vertex that replaced it in the builder.
     */
    private void connectRedirected(final DAGBuilder<IRVertex, IREdge> builder,
                                   final IREdge edge,
                                   final Map<LoopVertex, IRVertex> replacements) {
      final IRVertex src = replacements.getOrDefault(edge.getSrc(), edge.getSrc());
      final IRVertex dst = replacements.getOrDefault(edge.getDst(), edge.getDst());
      if (!builder.contains(src) || !builder.contains(dst)) {
        return;
      }
      if (src == edge.getSrc() && dst == edge.getDst()) {
        builder.connectVertices(edge);
      } else {
        final IREdge redirectedEdge =
          new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(), src, dst);
        edge.copyExecutionPropertiesTo(redirectedEdge);
        builder.connectVertices(redirectedEdge);
      }
    }