in plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java [54:110]
/**
 * Picks a node to run an ML task on and hands it to {@code listener}.
 *
 * <p>Node stats (executing ML task count and JVM heap usage) are requested for the
 * eligible data nodes. Candidates are then narrowed in two steps: nodes whose JVM heap
 * usage is below {@code DEFAULT_JVM_HEAP_USAGE_THRESHOLD}, then nodes whose executing
 * ML task count is below {@code maxMLBatchTaskPerNode}. Among the survivors the node
 * with the fewest executing tasks wins; ties are broken by lower JVM heap usage, and
 * remaining ties by the order in which the stats response lists the nodes.
 *
 * @param listener receives the chosen node via {@code onResponse}; receives a
 *                 {@link LimitExceededException} via {@code onFailure} when either
 *                 filtering step leaves no candidate, or the underlying exception when
 *                 the stats request fails
 */
public void dispatchTask(ActionListener<DiscoveryNode> listener) {
    // todo: add ML node type setting check
    // DiscoveryNode[] mlNodes = getEligibleMLNodes();
    DiscoveryNode[] mlNodes = getEligibleDataNodes();
    final String taskCountStat = ML_EXECUTING_TASK_COUNT.getName();
    final String heapUsageStat = JVM_HEAP_USAGE.getName();

    MLStatsNodesRequest statsRequest = new MLStatsNodesRequest(mlNodes);
    statsRequest.addAll(ImmutableSet.of(taskCountStat, heapUsageStat));

    client.execute(MLStatsNodesAction.INSTANCE, statsRequest, ActionListener.wrap(mlStatsResponse -> {
        // Step 1: drop nodes that are under JVM memory pressure.
        List<MLStatsNodeResponse> candidates = mlStatsResponse
            .getNodes()
            .stream()
            .filter(nodeStats -> readLongStat(nodeStats, heapUsageStat) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD)
            .collect(Collectors.toList());
        if (candidates.isEmpty()) {
            String errorMessage = "All nodes' memory usage exceeds limitation "
                + DEFAULT_JVM_HEAP_USAGE_THRESHOLD
                + ". No eligible node available to run ml jobs ";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }

        // Step 2: drop nodes that already run the maximum number of ML tasks.
        candidates = candidates
            .stream()
            .filter(nodeStats -> readLongStat(nodeStats, taskCountStat) < maxMLBatchTaskPerNode)
            .collect(Collectors.toList());
        if (candidates.isEmpty()) {
            String errorMessage = "All nodes' executing ML task count reach limitation.";
            log.warn(errorMessage);
            listener.onFailure(new LimitExceededException(errorMessage));
            return;
        }

        // Step 3: single pass for the least loaded node. The strict "less than" checks
        // keep the earliest candidate on a full tie, which is what a stable sort
        // followed by taking the first element would select.
        MLStatsNodeResponse leastLoaded = candidates.get(0);
        for (int i = 1; i < candidates.size(); i++) {
            MLStatsNodeResponse contender = candidates.get(i);
            int byTaskCount = Long.compare(readLongStat(contender, taskCountStat), readLongStat(leastLoaded, taskCountStat));
            boolean fewerTasks = byTaskCount < 0;
            // Same running task count: prefer the node with the lower JVM heap usage.
            boolean sameTasksLessHeap = byTaskCount == 0
                && readLongStat(contender, heapUsageStat) < readLongStat(leastLoaded, heapUsageStat);
            if (fewerTasks || sameTasksLessHeap) {
                leastLoaded = contender;
            }
        }
        listener.onResponse(leastLoaded.getNode());
    }, exception -> {
        log.error("Failed to get node's task stats", exception);
        listener.onFailure(exception);
    }));
}

/**
 * Reads one stat of a node response as a primitive {@code long}.
 *
 * <p>The value is cast to {@link Long}, so a missing stat fails with a
 * {@link NullPointerException} on unboxing and a value of another type fails with a
 * {@link ClassCastException}.
 */
private static long readLongStat(MLStatsNodeResponse nodeStats, String statName) {
    return (Long) nodeStats.getStatsMap().get(statName);
}