in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingTaskSizingPass.java [70:148]
public IRDAG apply(final IRDAG dag) {
  // Enables dynamic task sizing (DTS) when the job size calls for it. When enabled, it:
  //  - marks every vertex with EnableDynamicTaskSizingProperty,
  //  - sets a HASH partitioner on the selected shuffle edges,
  //  - inserts a splitter vertex for each stage that such an edge enters.
  // The given DAG is modified in place and returned. When DTS is not enabled, it is returned untouched.

  /* Step 1. Check DTS launch by job size. */
  final boolean enableDynamicTaskSizing = isDTSEnabledByJobSize(dag);
  if (!enableDynamicTaskSizing) {
    return dag;
  }
  dag.topologicalDo(v -> v.setProperty(EnableDynamicTaskSizingProperty.of(enableDynamicTaskSizing)));
  final int partitionerProperty = getPartitionerPropertyByJobSize(dag);

  /* Step 2-1. Group vertices by stage using stage merging logic. */
  final Map<IRVertex, Integer> vertexToStageId = stagePartitioner.apply(dag);
  final Map<Integer, Set<IRVertex>> stageIdToStageVertices = new HashMap<>();
  vertexToStageId.forEach((vertex, stageId) ->
    stageIdToStageVertices.computeIfAbsent(stageId, id -> new HashSet<>()).add(vertex));

  /* Step 2-2. Collect the target edges of DTS. */
  final Set<IREdge> shuffleEdgesForDTS = new HashSet<>();
  dag.topologicalDo(v -> {
    for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
      if (isAppropriateForInsertingSplitterVertex(dag, v, edge, vertexToStageId, stageIdToStageVertices)) {
        shuffleEdgesForDTS.add(edge);
      }
    }
  });

  /* Step 2-3. Change partitioner property for DTS target edges. */
  dag.topologicalDo(v -> {
    for (final IREdge edge : dag.getIncomingEdgesOf(v)) {
      if (shuffleEdgesForDTS.contains(edge)) {
        // Remove before mutating and re-add afterwards. The edge's hash presumably depends on its
        // execution properties, so mutating it while inside the HashSet could strand it in a stale
        // bucket and break the contains() lookups in Step 3.
        shuffleEdgesForDTS.remove(edge);
        edge.setProperty(PartitionerProperty.of(PartitionerProperty.Type.HASH, partitionerProperty));
        shuffleEdgesForDTS.add(edge);
      }
    }
  });

  /* Step 3. Insert Splitter Vertex for each stage entered by a DTS target edge. */
  final List<IRVertex> reverseTopologicalOrder = dag.getTopologicalSort();
  Collections.reverse(reverseTopologicalOrder);
  for (final IRVertex v : reverseTopologicalOrder) {
    for (final IREdge edge : dag.getOutgoingEdgesOf(v)) {
      if (!shuffleEdgesForDTS.contains(edge)) {
        continue;
      }
      // edge is the incoming edge of observing stage, v is the last vertex of previous stage
      final Set<IRVertex> stageVertices = stageIdToStageVertices.get(vertexToStageId.get(edge.getDst()));

      // Skip a stage whose incoming edges all originate inside the stage itself. Only this edge
      // is skipped: v's remaining outgoing edges may enter other stages that still need a splitter.
      final boolean isSourcePartition = stageVertices.stream()
        .flatMap(vertexInPartition -> dag.getIncomingEdgesOf(vertexInPartition).stream())
        .map(Edge::getSrc)
        .allMatch(stageVertices::contains);
      if (isSourcePartition) {
        continue;
      }

      // Stage vertices with at least one outgoing edge that leaves the stage.
      final Set<IRVertex> verticesWithStageOutgoingEdges = stageVertices.stream()
        .filter(stageVertex -> dag.getOutgoingEdgesOf(stageVertex).stream()
          .map(Edge::getDst)
          .anyMatch(dst -> !stageVertices.contains(dst)))
        .collect(Collectors.toCollection(HashSet::new));

      // Stage vertices with no outgoing edge that stays inside the stage (this includes sinks).
      final Set<IRVertex> stageEndingVertices = stageVertices.stream()
        .filter(stageVertex -> dag.getOutgoingEdgesOf(stageVertex).stream()
          .map(Edge::getDst)
          .noneMatch(stageVertices::contains))
        .collect(Collectors.toSet());

      insertSplitterVertex(dag, stageVertices, Collections.singleton(edge.getDst()),
        verticesWithStageOutgoingEdges, stageEndingVertices, partitionerProperty);
    }
  }
  return dag;
}