in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopOptimizations.java [116:206]
public IRDAG apply(final IRDAG inputDAG) {
inputDAG.reshapeUnsafely(dag -> {
final List<LoopVertex> loopVertices = new ArrayList<>();
final Map<LoopVertex, List<IREdge>> inEdges = new HashMap<>();
final Map<LoopVertex, List<IREdge>> outEdges = new HashMap<>();
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
collectLoopVertices(dag, loopVertices, inEdges, outEdges, builder);
// Collect and group those with same termination condition.
final Set<Set<LoopVertex>> setOfLoopsToBeFused = new HashSet<>();
loopVertices.forEach(loopVertex -> {
// We want loopVertices that are not dependent on each other
// or the list that is potentially going to be merged.
final List<LoopVertex> independentLoops = loopVertices.stream().filter(loop ->
setOfLoopsToBeFused.stream().anyMatch(list -> list.contains(loop))
? setOfLoopsToBeFused.stream().filter(list -> list.contains(loop))
.findFirst()
.map(list -> list.stream().noneMatch(loopV -> dag.pathExistsBetween(loopV, loopVertex)))
.orElse(false)
: !dag.pathExistsBetween(loop, loopVertex)).collect(Collectors.toList());
// Find loops to be fused together.
final Set<LoopVertex> loopsToBeFused = new HashSet<>();
loopsToBeFused.add(loopVertex);
independentLoops.forEach(independentLoop -> {
// add them to the list if those independent loops have equal termination conditions.
if (loopVertex.terminationConditionEquals(independentLoop)) {
loopsToBeFused.add(independentLoop);
}
});
// add this information to the setOfLoopsToBeFused set.
final Optional<Set<LoopVertex>> listToAddVerticesTo = setOfLoopsToBeFused.stream()
.filter(list -> list.stream().anyMatch(loopsToBeFused::contains)).findFirst();
if (listToAddVerticesTo.isPresent()) {
listToAddVerticesTo.get().addAll(loopsToBeFused);
} else {
setOfLoopsToBeFused.add(loopsToBeFused);
}
});
// merge and add to builder.
setOfLoopsToBeFused.forEach(loops -> {
if (loops.size() > 1) {
final LoopVertex newLoopVertex = mergeLoopVertices(loops);
builder.addVertex(newLoopVertex, dag);
loops.forEach(loopVertex -> {
// inEdges.
inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
if (builder.contains(irEdge.getSrc())) {
final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
.get(), irEdge.getSrc(), newLoopVertex);
irEdge.copyExecutionPropertiesTo(newIREdge);
builder.connectVertices(newIREdge);
}
});
// outEdges.
outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(irEdge -> {
if (builder.contains(irEdge.getDst())) {
final IREdge newIREdge = new IREdge(irEdge.getPropertyValue(CommunicationPatternProperty.class)
.get(), newLoopVertex, irEdge.getDst());
irEdge.copyExecutionPropertiesTo(newIREdge);
builder.connectVertices(newIREdge);
}
});
});
} else {
loops.forEach(loopVertex -> {
builder.addVertex(loopVertex);
// inEdges.
inEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
if (builder.contains(edge.getSrc())) {
builder.connectVertices(edge);
}
});
// outEdges.
outEdges.getOrDefault(loopVertex, new ArrayList<>()).forEach(edge -> {
if (builder.contains(edge.getDst())) {
builder.connectVertices(edge);
}
});
});
}
});
return builder.build();
});
return inputDAG;
}