public IRDAG apply()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/CommonSubexpressionEliminationPass.java [54:134]


  public IRDAG apply(final IRDAG inputDAG) {
    // Eliminates common subexpressions among OperatorVertices, in three phases:
    //  1. Walk the DAG in topological order. Non-operator vertices, and edges between two non-operator
    //     vertices, are copied straight into the builder. Operator vertices are bucketed by Transform
    //     (HashMap key, i.e. Transform#equals/hashCode). Every edge touching an operator vertex is deferred.
    //  2. Within each Transform bucket, sub-group the vertices whose sets of source vertices are equal, and
    //     hand each sub-group to mergeAndAddToBuilder.
    //  3. Connect each deferred edge whose operator endpoint and other endpoint are both in the builder.
    // The rebuilt DAG replaces the contents of inputDAG via reshapeUnsafely; inputDAG itself is returned.
    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
    final Map<Transform, List<OperatorVertex>> operatorVerticesToBeMerged = new HashMap<>();
    // Edges incident to operator vertices, keyed by the operator vertex at the respective end.
    final Map<OperatorVertex, Set<IREdge>> inEdges = new HashMap<>();
    final Map<OperatorVertex, Set<IREdge>> outEdges = new HashMap<>();

    inputDAG.reshapeUnsafely(dag -> {
      // Phase 1: bucket operator vertices, defer their edges, copy everything else.
      dag.topologicalDo(irVertex -> {
        if (irVertex instanceof OperatorVertex) {
          final OperatorVertex operatorVertex = (OperatorVertex) irVertex;
          operatorVerticesToBeMerged.computeIfAbsent(operatorVertex.getTransform(), t -> new ArrayList<>())
            .add(operatorVertex);

          dag.getIncomingEdgesOf(operatorVertex).forEach(irEdge -> {
            inEdges.computeIfAbsent(operatorVertex, v -> new HashSet<>()).add(irEdge);
            if (irEdge.getSrc() instanceof OperatorVertex) {
              outEdges.computeIfAbsent((OperatorVertex) irEdge.getSrc(), v -> new HashSet<>()).add(irEdge);
            }
          });
        } else {
          builder.addVertex(irVertex, dag);
          dag.getIncomingEdgesOf(irVertex).forEach(irEdge -> {
            if (irEdge.getSrc() instanceof OperatorVertex) {
              // The operator source may later be merged, so this edge is connected in phase 3.
              outEdges.computeIfAbsent((OperatorVertex) irEdge.getSrc(), v -> new HashSet<>()).add(irEdge);
            } else {
              // Topological order guarantees the non-operator source is already in the builder.
              builder.connectVertices(irEdge);
            }
          });
        }
      });

      // Phase 2: within each Transform bucket, merge vertices that read from exactly the same source vertices.
      // In a DAG such vertices cannot depend on each other: if one were an ancestor of the other, one of the
      // shared sources would be both its descendant and its direct source, forming a cycle.
      operatorVerticesToBeMerged.values().forEach(operatorVertices -> {
        final Map<Set<IRVertex>, List<OperatorVertex>> verticesToBeMergedWithIdenticalSources = new HashMap<>();
        operatorVertices.forEach(operatorVertex -> {
          final Set<IRVertex> incomingVertices = dag.getIncomingEdgesOf(operatorVertex).stream()
            .map(IREdge::getSrc)
            .collect(Collectors.toSet());
          // Set keys compare by Set#equals/hashCode (element-wise set equality), so a direct hash lookup
          // groups vertices with identical source sets without scanning every existing key.
          verticesToBeMergedWithIdenticalSources.computeIfAbsent(incomingVertices, key -> new ArrayList<>())
            .add(operatorVertex);
        });

        verticesToBeMergedWithIdenticalSources.values().forEach(ovs ->
          mergeAndAddToBuilder(ovs, builder, dag, inEdges, outEdges));
      });

      // Phase 3: connect deferred edges whose endpoints both made it into the builder.
      operatorVerticesToBeMerged.values().forEach(operatorVertices ->
        operatorVertices.forEach(operatorVertex -> {
          // Operator vertices that mergeAndAddToBuilder did not add (presumably the ones merged into a
          // representative) have no edges to connect; this check does not change while connecting edges.
          if (!builder.contains(operatorVertex)) {
            return;
          }
          for (final IREdge e : inEdges.getOrDefault(operatorVertex, new HashSet<>())) {
            if (builder.contains(e.getSrc())) {
              builder.connectVertices(e);
            }
          }
          for (final IREdge e : outEdges.getOrDefault(operatorVertex, new HashSet<>())) {
            if (builder.contains(e.getDst())) {
              builder.connectVertices(e);
            }
          }
        }));

      return builder.build();
    });

    return inputDAG;
  }