in compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/ResourceSitePass.java [112:159]
/**
 * Annotates every vertex of the DAG with a {@link ResourceSiteProperty}, i.e. a map from node name
 * to the number of this vertex's parallel tasks that should be placed on that node.
 *
 * <p>Vertices are visited through {@code dag.topologicalDo}, so that the property of each parent
 * is expected to be set before its children read it. Three cases are distinguished:
 * <ul>
 *   <li>no incoming edge: the vertex gets {@code EMPTY_MAP}; an empty map is later replaced by an
 *       even distribution when a child aggregates its parents' shares;</li>
 *   <li>{@code isOneToOneEdge(inEdges)} holds: the vertex reuses the shares of its parent;</li>
 *   <li>otherwise: the shares of all parents are summed per node, handed to {@code optimize}, and
 *       the returned ratios are scaled by the parallelism of the vertex.</li>
 * </ul>
 *
 * @param dag the IR DAG whose vertices are annotated in place.
 * @param bandwidthSpecification the bandwidth specification providing the list of node names.
 * @throws RuntimeException if a vertex lacks the parallelism property, or if a parent vertex
 *                          lacks the resource site property when a child needs it.
 */
private static void assignNodeShares(
  final IRDAG dag,
  final BandwidthSpecification bandwidthSpecification) {
  dag.topologicalDo(irVertex -> {
    final Collection<IREdge> inEdges = dag.getIncomingEdgesOf(irVertex);
    final int parallelism = irVertex.getPropertyValue(ParallelismProperty.class)
      .orElseThrow(() -> new RuntimeException("Parallelism property required"));
    if (inEdges.isEmpty()) {
      // This vertex is root vertex.
      // Fall back to setting even distribution
      irVertex.setProperty(ResourceSiteProperty.of(EMPTY_MAP));
    } else if (isOneToOneEdge(inEdges)) {
      // Co-locate with the parent: reuse the parent's shares as they are.
      final HashMap<String, Integer> inheritedShares = inEdges.iterator().next().getSrc()
        .getPropertyValue(ResourceSiteProperty.class)
        .orElseThrow(() -> new RuntimeException("ResourceSite property required"));
      irVertex.setProperty(ResourceSiteProperty.of(inheritedShares));
    } else {
      // This IRVertex has shuffle inEdge(s), or has multiple inEdges.
      // Sum up, per node, the number of parent tasks located there.
      final Map<String, Integer> parentLocationShares = new HashMap<>();
      for (final IREdge edgeToIRVertex : inEdges) {
        final IRVertex parentVertex = edgeToIRVertex.getSrc();
        final Map<String, Integer> parentShares = parentVertex.getPropertyValue(ResourceSiteProperty.class)
          .orElseThrow(() -> new RuntimeException("ResourceSite property required"));
        final int parentParallelism = parentVertex.getPropertyValue(ParallelismProperty.class)
          .orElseThrow(() -> new RuntimeException("Parallelism property required"));
        // An empty map stands for "no explicit placement": treat it as an even distribution.
        final Map<String, Integer> effectiveShares = parentShares.isEmpty()
          ? getEvenShares(bandwidthSpecification.getNodes(), parentParallelism)
          : parentShares;
        for (final Map.Entry<String, Integer> element : effectiveShares.entrySet()) {
          parentLocationShares.merge(element.getKey(), element.getValue(), Integer::sum);
        }
      }
      // ratios[i] is applied to the i-th node of bandwidthSpecification.getNodes().
      final double[] ratios = optimize(bandwidthSpecification, parentLocationShares);
      final HashMap<String, Integer> shares = new HashMap<>();
      for (int i = 0; i < bandwidthSpecification.getNodes().size(); i++) {
        // Truncation towards zero; the lost tasks are handed out below.
        shares.put(bandwidthSpecification.getNodes().get(i), (int) (ratios[i] * parallelism));
      }
      // Hand out the tasks lost to truncation, one per node, until the total matches parallelism.
      int remainder = parallelism - shares.values().stream().mapToInt(Integer::intValue).sum();
      for (final Map.Entry<String, Integer> share : shares.entrySet()) {
        // '<=' rather than '==': a negative remainder must not add even more tasks.
        if (remainder <= 0) {
          break;
        }
        share.setValue(share.getValue() + 1);
        remainder--;
      }
      irVertex.setProperty(ResourceSiteProperty.of(shares));
    }
  });
}