in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java [253:314]
/**
 * Hoists loop-invariant vertices out of every {@link LoopVertex} of the given DAG, repeating until nothing changes.
 *
 * <p>A vertex inside a loop is hoisted when, in the loop body, it has no incoming edges from other loop vertices
 * and no iterative incoming edges, i.e. it is only fed through the loop's non-iterative incoming edges. The vertex
 * is moved into the enclosing DAG, and the loop body is rewired to receive its output as input from outside.</p>
 *
 * <p>Candidates are selected per loop before any of them is moved; vertices that only become hoistable once their
 * predecessors are moved can be picked up by the recursive call. Recursion stops when the rebuilt DAG has the same
 * number of vertices as the input DAG.</p>
 *
 * @param dag the DAG to optimize.
 * @return the optimized DAG.
 */
DAG<IRVertex, IREdge> recursivelyOptimize(final DAG<IRVertex, IREdge> dag) {
  final List<LoopVertex> loopVertices = new ArrayList<>();
  final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
  final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
  final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
  collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);

  // Refactor those with same data scan / operation, without dependencies in the loop.
  for (final LoopVertex loopVertex : loopVertices) {
    // The candidate list is computed once, against the loop body as it was before any candidate is moved.
    for (final Map.Entry<IRVertex, Set<IREdge>> candidate : findHoistCandidates(loopVertex)) {
      hoistCandidate(loopVertex, candidate, inEdges, builder);
    }
  }

  // Add the loop vertices, together with their incoming and outgoing edges, to the enclosing DAG.
  for (final LoopVertex loopVertex : loopVertices) {
    builder.addVertex(loopVertex);
    inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
    outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(builder::connectVertices);
  }

  final DAG<IRVertex, IREdge> newDag = builder.build();
  // Every hoisted vertex is added to the enclosing builder, so an unchanged vertex count is taken as the signal
  // that this pass hoisted nothing (assuming collectLoopVertices re-adds all other vertices).
  return dag.getVertices().size() == newDag.getVertices().size() ? newDag : recursivelyOptimize(newDag);
}

/**
 * Selects the vertices of {@code loopVertex} that can be hoisted: those with no incoming edges inside the loop
 * body and no iterative incoming edges.
 *
 * @param loopVertex the loop to inspect.
 * @return the entries of the loop's non-iterative incoming edges map (vertex to its non-iterative incoming edges)
 *         for the hoistable vertices; empty if there are none.
 */
private List<Map.Entry<IRVertex, Set<IREdge>>> findHoistCandidates(final LoopVertex loopVertex) {
  final Map<IRVertex, Set<IREdge>> nonIterativeInEdges = loopVertex.getNonIterativeIncomingEdges();
  if (nonIterativeInEdges.isEmpty()) {
    return new ArrayList<>();
  }
  // Fetch the loop body once rather than once per entry; nothing modifies it while filtering.
  final DAG<IRVertex, IREdge> loopDag = loopVertex.getDAG();
  return nonIterativeInEdges.entrySet().stream()
      .filter(entry ->
          // no incoming edges from other vertices inside the loop body
          loopDag.getIncomingEdgesOf(entry.getKey()).isEmpty()
          // no iterative incoming edges
          && loopVertex.getIterativeIncomingEdges().getOrDefault(entry.getKey(), new HashSet<>()).isEmpty())
      .collect(Collectors.toList());
}

/**
 * Moves one hoistable vertex out of {@code loopVertex} into the enclosing DAG and rewires the edges around it.
 *
 * @param loopVertex the loop the vertex is moved out of.
 * @param candidate  the vertex to hoist, mapped to its non-iterative incoming edges.
 * @param inEdges    incoming edges of each loop vertex in the enclosing DAG; the list for {@code loopVertex}
 *                   (if present) is updated in place.
 * @param builder    builder of the enclosing DAG; receives the hoisted vertex and its incoming edges.
 */
private void hoistCandidate(final LoopVertex loopVertex,
                            final Map.Entry<IRVertex, Set<IREdge>> candidate,
                            final Map<LoopVertex, List<IREdge>> inEdges,
                            final DAGBuilder<IRVertex, IREdge> builder) {
  final IRVertex hoisted = candidate.getKey();
  final Set<IREdge> edgesIntoHoisted = candidate.getValue();

  // Place the vertex in the enclosing DAG, fed by its non-iterative incoming edges.
  builder.addVertex(hoisted);
  edgesIntoHoisted.forEach(builder::connectVertices);

  // Its consumers inside the loop now receive its output as incoming edges from outside the loop body.
  final List<IREdge> edgesOutOfHoisted = loopVertex.getDAG().getOutgoingEdgesOf(hoisted);
  edgesOutOfHoisted.forEach(loopVertex::addDagIncomingEdge);
  edgesOutOfHoisted.forEach(loopVertex::addNonIterativeIncomingEdge);

  // Loop in-edges whose source also feeds the hoisted vertex are redirected to start at the hoisted vertex.
  // NOTE(review): the match is by source only. If several candidates of this loop share a source, the first one
  // redirects every in-edge from that source, so later candidates find nothing to redirect and get no edge to
  // loopVertex in the enclosing DAG. Confirm this is intended.
  final List<IREdge> loopInEdges = inEdges.getOrDefault(loopVertex, new ArrayList<>());
  final List<IREdge> edgesToRemove = loopInEdges.stream()
      .filter(edge -> edgesIntoHoisted.stream().map(IREdge::getSrc).anyMatch(src -> src.equals(edge.getSrc())))
      .collect(Collectors.toList());
  final List<IREdge> edgesToAdd = edgesToRemove.stream()
      .map(edge -> copyWithSource(edge, hoisted))
      .collect(Collectors.toList());
  loopInEdges.removeAll(edgesToRemove);
  loopInEdges.addAll(edgesToAdd);

  // Clear garbage: drop the hoisted vertex from the loop body and from the loop's incoming-edge maps.
  loopVertex.getBuilder().removeVertex(hoisted);
  loopVertex.getDagIncomingEdges().remove(hoisted);
  loopVertex.getNonIterativeIncomingEdges().remove(hoisted);
}

/**
 * Creates a new edge from {@code newSrc} to the destination of {@code edge}. Only the communication pattern,
 * encoder and decoder are taken from {@code edge}.
 * NOTE(review): each {@code .get()} below assumes the property is present on {@code edge}; confirm callers
 * guarantee this.
 *
 * @param edge   the edge to copy.
 * @param newSrc the source vertex of the new edge.
 * @return the new edge.
 */
private IREdge copyWithSource(final IREdge edge, final IRVertex newSrc) {
  final IREdge copy = new IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
      newSrc, edge.getDst());
  copy.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
  copy.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
  return copy;
}