DAG recursivelyOptimize()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java [253:314]


    /**
     * Moves loop-invariant vertices out of the loop vertices of the given DAG and rebuilds the DAG.
     *
     * A vertex inside a loop is moved out when it has non-iterative incoming edges, no incoming edges
     * within the loop's own DAG, and no iterative incoming edges. The moved vertex is added to the
     * rebuilt DAG, and the loop's recorded incoming edges that shared a source with it are replaced
     * by edges that start from the moved vertex instead.
     *
     * The method calls itself on the rebuilt DAG whenever the rebuilt DAG's vertex count differs
     * from the input DAG's vertex count.
     *
     * @param dag the DAG to optimize.
     * @return the rebuilt DAG once a pass leaves the number of vertices unchanged.
     */
    DAG<IRVertex, IREdge> recursivelyOptimize(final DAG<IRVertex, IREdge> dag) {
      final List<LoopVertex> loops = new ArrayList<>();
      final Map<LoopVertex, List<IREdge>> loopInEdges = new HashMap<>();
      final Map<LoopVertex, List<IREdge>> loopOutEdges = new HashMap<>();
      final DAGBuilder<IRVertex, IREdge> dagBuilder = new DAGBuilder<>();

      // NOTE(review): presumably fills the collections above and seeds the builder with the
      // non-loop part of the DAG -- confirm against collectLoopVertices.
      collectLoopVertices(dag, loops, loopInEdges, loopOutEdges, dagBuilder);

      for (final LoopVertex loop : loops) {
        // Pick the vertices that depend on nothing produced by the loop itself.
        final List<Map.Entry<IRVertex, Set<IREdge>>> hoistable =
          loop.getNonIterativeIncomingEdges().entrySet().stream()
            .filter(entry -> {
              final IRVertex vertex = entry.getKey();
              // Fed from inside the loop's DAG: not invariant.
              if (!loop.getDAG().getIncomingEdgesOf(vertex).isEmpty()) {
                return false;
              }
              // Must not receive iterative input either.
              return loop.getIterativeIncomingEdges().getOrDefault(vertex, new HashSet<>()).isEmpty();
            })
            .collect(Collectors.toList());

        for (final Map.Entry<IRVertex, Set<IREdge>> hoisted : hoistable) {
          final IRVertex invariant = hoisted.getKey();

          // The vertex now lives in the outer DAG, together with the edges that fed it.
          dagBuilder.addVertex(invariant);
          for (final IREdge feedingEdge : hoisted.getValue()) {
            dagBuilder.connectVertices(feedingEdge);
          }

          // Its former outgoing edges become incoming edges of the loop: first registered as
          // DAG-incoming edges, then as non-iterative incoming edges.
          for (final IREdge outgoing : loop.getDAG().getOutgoingEdgesOf(invariant)) {
            loop.addDagIncomingEdge(outgoing);
          }
          for (final IREdge outgoing : loop.getDAG().getOutgoingEdgesOf(invariant)) {
            loop.addNonIterativeIncomingEdge(outgoing);
          }

          // Re-route the loop's recorded incoming edges that come from a source of the moved vertex.
          final List<IREdge> currentInEdges = loopInEdges.getOrDefault(loop, new ArrayList<>());
          final List<IREdge> staleEdges = new ArrayList<>();
          final List<IREdge> reroutedEdges = new ArrayList<>();
          for (final IREdge inEdge : currentInEdges) {
            final boolean sharesSourceWithHoisted = hoisted.getValue().stream()
              .map(IREdge::getSrc)
              .anyMatch(src -> src.equals(inEdge.getSrc()));
            if (!sharesSourceWithHoisted) {
              continue;
            }
            staleEdges.add(inEdge);
            // The replacement starts at the moved vertex and keeps the old edge's communication
            // pattern, encoder and decoder.
            final IREdge rerouted = new IREdge(inEdge.getPropertyValue(CommunicationPatternProperty.class).get(),
              invariant, inEdge.getDst());
            rerouted.setProperty(EncoderProperty.of(inEdge.getPropertyValue(EncoderProperty.class).get()));
            rerouted.setProperty(DecoderProperty.of(inEdge.getPropertyValue(DecoderProperty.class).get()));
            reroutedEdges.add(rerouted);
          }
          currentInEdges.removeAll(staleEdges);
          currentInEdges.addAll(reroutedEdges);

          // Drop every trace of the moved vertex from the loop.
          loop.getBuilder().removeVertex(invariant);
          loop.getDagIncomingEdges().remove(invariant);
          loop.getNonIterativeIncomingEdges().remove(invariant);
        }
      }

      // Put the (possibly slimmed-down) loop vertices and their edges into the rebuilt DAG.
      for (final LoopVertex loop : loops) {
        dagBuilder.addVertex(loop);
        for (final IREdge incoming : loopInEdges.getOrDefault(loop, new ArrayList<>())) {
          dagBuilder.connectVertices(incoming);
        }
        for (final IREdge outgoing : loopOutEdges.getOrDefault(loop, new ArrayList<>())) {
          dagBuilder.connectVertices(outgoing);
        }
      }

      final DAG<IRVertex, IREdge> rebuilt = dagBuilder.build();
      // An unchanged vertex count ends the recursion; otherwise run another pass.
      final boolean vertexCountUnchanged = dag.getVertices().size() == rebuilt.getVertices().size();
      return vertexCountUnchanged ? rebuilt : recursivelyOptimize(rebuilt);
    }