in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java [74:126]
public IRDAG apply(final IRDAG dag) {
  // Overview:
  // For each vertex v, visited via dag.topologicalDo, and each incoming SHUFFLE edge e of v:
  //   1. Build the partition containing e's source (the shuffle writer) and find the partition's
  //      source vertices, i.e. members that have no incoming edge from another member.
  //   2. If no edge leaves the partition (every outgoing edge of every member stays inside it),
  //      skip e: no sampling vertices are created for such a "sink" partition.
  //   3. Otherwise create one SamplingVertex (rate SAMPLE_RATE) per partition member and pass them,
  //      together with the partition's source vertices, to dag.insert.
  //   4. Insert a MessageGeneratorVertex / MessageAggregatorVertex pair built from the clone of e
  //      held by the sampling copy of the shuffle writer, targeting the original edge e.
  // The dag argument is mutated in place and returned.
  dag.topologicalDo(v -> {
    // NOTE(review): dag.insert(...) is called while iterating v's incoming edges. This assumes the
    // collection returned by getIncomingEdgesOf is not mutated in place by insert() — confirm in IRDAG.
    for (final IREdge e : dag.getIncomingEdgesOf(v)) {
      if (!CommunicationPatternProperty.Value.SHUFFLE.equals(
        e.getPropertyValue(CommunicationPatternProperty.class).get())) {
        continue;
      }

      // Compute the partition and its source vertices.
      final IRVertex shuffleWriter = e.getSrc();
      final Set<IRVertex> partitionAll = recursivelyBuildPartition(shuffleWriter, dag);
      final Set<IRVertex> partitionSources = partitionAll.stream()
        .filter(vertexInPartition -> dag.getIncomingEdgesOf(vertexInPartition).stream()
          .map(Edge::getSrc)
          .noneMatch(partitionAll::contains))
        .collect(Collectors.toSet());

      // A partition with no edge leaving it is a sink: do not create sampling vertices for it.
      // `continue` (not `break`) so that v's remaining incoming shuffle edges are still processed.
      final boolean isSinkPartition = partitionAll.stream()
        .flatMap(vertexInPartition -> dag.getOutgoingEdgesOf(vertexInPartition).stream())
        .map(Edge::getDst)
        .allMatch(partitionAll::contains);
      if (isSinkPartition) {
        continue;
      }

      // Insert sampling vertices.
      final Set<SamplingVertex> samplingVertices = partitionAll.stream()
        .map(vertexInPartition -> new SamplingVertex(vertexInPartition, SAMPLE_RATE))
        .collect(Collectors.toSet());
      dag.insert(samplingVertices, partitionSources);

      // Insert the message vertex.
      // We first obtain a clonedShuffleEdge to analyze the data statistics of the shuffle outputs of
      // the sampling vertex right before shuffle.
      final SamplingVertex rightBeforeShuffle = samplingVertices.stream()
        .filter(sv -> sv.getOriginalVertexId().equals(shuffleWriter.getId()))
        .findFirst()
        .orElseThrow(() -> new IllegalStateException(
          "No sampling vertex was created for shuffle writer " + shuffleWriter.getId()));
      final IREdge clonedShuffleEdge = rightBeforeShuffle.getCloneOfOriginalEdge(e);
      final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class)
        .orElseThrow(() -> new IllegalStateException("Shuffle edge from " + shuffleWriter.getId()
          + " to " + v.getId() + " has no KeyExtractorProperty"));
      dag.insert(
        new MessageGeneratorVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
        new MessageAggregatorVertex(HashMap::new, SkewHandlingUtil.getMessageAggregator()),
        SkewHandlingUtil.getEncoder(e),
        SkewHandlingUtil.getDecoder(e),
        new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
        new HashSet<>(Arrays.asList(e))); // we want to optimize the original edge, not the clone
    }
  });
  return dag;
}