in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/CommonSubexpressionEliminationPass.java [54:134]
public IRDAG apply(final IRDAG inputDAG) {
  // Common subexpression elimination over the IR DAG, done in three phases inside reshapeUnsafely:
  //   1. Walk the DAG topologically. Copy every non-operator vertex (and edges between two
  //      non-operator vertices) straight into a new builder. Bucket operator vertices by their
  //      Transform and remember their incoming/outgoing edges for later.
  //   2. Within each Transform bucket, group the vertices whose incoming edges come from exactly the
  //      same set of source vertices, and hand each group to mergeAndAddToBuilder.
  //   3. Re-connect every remembered edge whose two endpoints ended up in the builder.
  // The reshaped DAG is the one produced by builder.build(); inputDAG itself is returned.
  final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
  // Operator vertices bucketed by transform (bucketing relies on Transform#equals/hashCode).
  final Map<Transform, List<OperatorVertex>> operatorVerticesToBeMerged = new HashMap<>();
  // Edges entering/leaving each operator vertex. Connecting them is deferred to phase 3, because
  // only the vertices that mergeAndAddToBuilder places into the builder can be connected.
  final Map<OperatorVertex, Set<IREdge>> inEdges = new HashMap<>();
  final Map<OperatorVertex, Set<IREdge>> outEdges = new HashMap<>();
  inputDAG.reshapeUnsafely(dag -> {
    // Phase 1: collect operator vertices and their edges; copy everything else into the builder.
    dag.topologicalDo(irVertex -> {
      if (irVertex instanceof OperatorVertex) {
        final OperatorVertex operatorVertex = (OperatorVertex) irVertex;
        operatorVerticesToBeMerged
            .computeIfAbsent(operatorVertex.getTransform(), t -> new ArrayList<>())
            .add(operatorVertex);
        dag.getIncomingEdgesOf(operatorVertex).forEach(irEdge -> {
          inEdges.computeIfAbsent(operatorVertex, v -> new HashSet<>()).add(irEdge);
          if (irEdge.getSrc() instanceof OperatorVertex) {
            final OperatorVertex source = (OperatorVertex) irEdge.getSrc();
            outEdges.computeIfAbsent(source, v -> new HashSet<>()).add(irEdge);
          }
        });
      } else {
        builder.addVertex(irVertex, dag);
        dag.getIncomingEdgesOf(irVertex).forEach(irEdge -> {
          if (irEdge.getSrc() instanceof OperatorVertex) {
            // The operator source may be merged away, so this edge is connected in phase 3.
            final OperatorVertex source = (OperatorVertex) irEdge.getSrc();
            outEdges.computeIfAbsent(source, v -> new HashSet<>()).add(irEdge);
          } else {
            // Non-operator source: topological order guarantees it was already added above.
            builder.connectVertices(irEdge);
          }
        });
      }
    });

    // Phase 2: within each transform bucket, merge vertices that read from identical source vertices.
    // NOTE(review): only the set of source vertices is compared, not the properties of the incoming
    // edges themselves -- confirm that mergeAndAddToBuilder / callers make this sufficient.
    operatorVerticesToBeMerged.forEach((transform, operatorVertices) -> {
      // Keys are Sets, so HashMap lookup uses Set#equals/hashCode: two vertices land in the same
      // group exactly when their source-vertex sets contain the same vertices (order-independent).
      final Map<Set<IRVertex>, List<OperatorVertex>> verticesToBeMergedWithIdenticalSources = new HashMap<>();
      operatorVertices.forEach(operatorVertex -> {
        final Set<IRVertex> incomingVertices = dag.getIncomingEdgesOf(operatorVertex).stream()
            .map(IREdge::getSrc)
            .collect(Collectors.toSet());
        verticesToBeMergedWithIdenticalSources
            .computeIfAbsent(incomingVertices, k -> new ArrayList<>())
            .add(operatorVertex);
      });
      // Presumably adds one representative per group to the builder and rewires the edge maps;
      // its body is outside this view -- TODO confirm.
      verticesToBeMergedWithIdenticalSources.values().forEach(ovs ->
          mergeAndAddToBuilder(ovs, builder, dag, inEdges, outEdges));
    });

    // Phase 3: connect deferred edges whose endpoints both survived into the builder.
    // Only edges are added in this phase, so builder.contains(operatorVertex) is loop-invariant.
    operatorVerticesToBeMerged.values().stream()
        .flatMap(List::stream)
        .filter(operatorVertex -> builder.contains(operatorVertex))
        .forEach(operatorVertex -> {
          inEdges.getOrDefault(operatorVertex, new HashSet<>()).stream()
              .filter(e -> builder.contains(e.getSrc()))
              .forEach(e -> builder.connectVertices(e));
          outEdges.getOrDefault(operatorVertex, new HashSet<>()).stream()
              .filter(e -> builder.contains(e.getDst()))
              .forEach(e -> builder.connectVertices(e));
        });
    return builder.build();
  });
  return inputDAG;
}