public IRDAG apply()

in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/SamplingSkewReshapingPass.java [74:126]


  /**
   * Inserts sampling vertices and a message generator/aggregator pair for each shuffle edge of the DAG.
   *
   * <p>Vertices are visited in topological order. For each incoming SHUFFLE edge {@code e} of a vertex,
   * the partition containing {@code e}'s source (the shuffle writer) is computed with
   * {@code recursivelyBuildPartition}. If that partition is a sink, i.e. no vertex in it has an outgoing
   * edge to a vertex outside it, the edge is skipped. Otherwise:
   * <ul>
   *   <li>one {@link SamplingVertex} (with {@code SAMPLE_RATE}) is inserted per vertex of the partition,
   *       rooted at the partition's source vertices; and</li>
   *   <li>a message generator/aggregator pair is inserted that observes the sampling vertex's clone of
   *       {@code e} and targets the original {@code e}.</li>
   * </ul>
   *
   * @param dag the DAG to reshape; {@code insert} is called on it directly.
   * @return the same {@code dag} instance that was passed in.
   * @throws IllegalStateException if none of the inserted sampling vertices corresponds to the shuffle writer.
   */
  public IRDAG apply(final IRDAG dag) {
    dag.topologicalDo(v -> {
      for (final IREdge e : dag.getIncomingEdgesOf(v)) {
        if (!CommunicationPatternProperty.Value.SHUFFLE.equals(
          e.getPropertyValue(CommunicationPatternProperty.class).get())) {
          continue;
        }

        // Compute the partition that contains the shuffle writer.
        final IRVertex shuffleWriter = e.getSrc();
        final Set<IRVertex> partitionAll = recursivelyBuildPartition(shuffleWriter, dag);

        // A partition is a sink when every outgoing edge of its vertices stays inside the partition
        // (this includes a partition with no outgoing edges at all). We do not sample sink partitions.
        final boolean isSinkPartition = partitionAll.stream()
          .flatMap(vertexInPartition -> dag.getOutgoingEdgesOf(vertexInPartition).stream())
          .map(Edge::getDst)
          .allMatch(partitionAll::contains);
        if (isSinkPartition) {
          // Skip only this shuffle edge; other incoming shuffle edges of v must still be processed.
          // (A `break` here would silently drop all of them.)
          continue;
        }

        // Source vertices of the partition: those with no incoming edge from inside the partition.
        final Set<IRVertex> partitionSources = partitionAll.stream().filter(vertexInPartition ->
          !dag.getIncomingEdgesOf(vertexInPartition).stream()
            .map(Edge::getSrc)
            .anyMatch(partitionAll::contains)
        ).collect(Collectors.toSet());

        // Insert sampling vertices.
        final Set<SamplingVertex> samplingVertices = partitionAll
          .stream()
          .map(vertexInPartition -> new SamplingVertex(vertexInPartition, SAMPLE_RATE))
          .collect(Collectors.toSet());
        dag.insert(samplingVertices, partitionSources);

        // Insert the message vertex.
        // We first obtain a clonedShuffleEdge to analyze the data statistics of the shuffle outputs of
        // the sampling vertex right before shuffle.
        final SamplingVertex rightBeforeShuffle = samplingVertices.stream()
          .filter(sv -> sv.getOriginalVertexId().equals(shuffleWriter.getId()))
          .findFirst()
          .orElseThrow(() -> new IllegalStateException(
            "No sampling vertex found for shuffle writer " + shuffleWriter.getId()));
        final IREdge clonedShuffleEdge = rightBeforeShuffle.getCloneOfOriginalEdge(e);

        final KeyExtractor keyExtractor = e.getPropertyValue(KeyExtractorProperty.class).get();
        dag.insert(
          new MessageGeneratorVertex<>(SkewHandlingUtil.getMessageGenerator(keyExtractor)),
          new MessageAggregatorVertex(HashMap::new, SkewHandlingUtil.getMessageAggregator()),
          SkewHandlingUtil.getEncoder(e),
          SkewHandlingUtil.getDecoder(e),
          new HashSet<>(Arrays.asList(clonedShuffleEdge)), // this works although the clone is not in the dag
          new HashSet<>(Arrays.asList(e))); // we want to optimize the original edge, not the clone
      }
    });

    return dag;
  }