in frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java [41:145]
/**
 * Runs the workflow DAG, or (when {@code topoSortedList} is non-null) only computes a
 * topological ordering of its nodes.
 *
 * <p>When {@code topoSortedList} is {@code null}, each node is invoked through {@code
 * invokeModel} on a fixed pool of 4 worker threads. Start nodes get the headers and parameters of
 * {@code input}. Every other node is submitted once all of its parents have completed and
 * receives the parents' {@code byte[]} outputs as parameters. Completed results are consumed one
 * at a time while other nodes may still be running. The pool is shut down before this method
 * returns, including on failure.
 *
 * <p>When {@code topoSortedList} is non-null, no model is invoked. Node names are appended to
 * {@code topoSortedList} in topological order, and the returned outputs carry {@code null} data.
 *
 * @param input the incoming request whose headers (and, for start nodes, parameters) are
 *     forwarded to node inputs; only read when {@code topoSortedList} is {@code null}
 * @param topoSortedList {@code null} to execute the DAG, otherwise a list that receives a
 *     topological ordering of node names
 * @return the outputs of the leaf nodes, i.e. nodes with no outgoing edges
 * @throws InternalServerException if a node invocation fails or the waiting thread is interrupted
 */
public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {
    // A null list means "invoke the models"; a non-null list means "only compute the order".
    boolean invokeModels = topoSortedList == null;
    ExecutorService executorService = null;
    CompletionService<NodeOutput> executorCompletionService = null;
    if (invokeModels) {
        executorService = Executors.newFixedThreadPool(4);
        executorCompletionService = new ExecutorCompletionService<>(executorService);
    }
    try {
        Map<String, Integer> inDegreeMap = this.dag.getInDegreeMap();
        Set<String> zeroInDegree = dag.getStartNodeNames();
        Set<String> executing = new HashSet<>();
        if (invokeModels) {
            for (String s : zeroInDegree) {
                RequestInput newInput = new RequestInput(UUID.randomUUID().toString());
                newInput.setHeaders(input.getHeaders());
                newInput.setParameters(input.getParameters());
                inputRequestMap.put(s, newInput);
            }
        }
        ArrayList<NodeOutput> leafOutputs = new ArrayList<>();
        while (!zeroInDegree.isEmpty()) {
            // Nodes whose parents are all done but which have not been submitted yet.
            Set<String> readyToExecute = new HashSet<>(zeroInDegree);
            readyToExecute.removeAll(executing);
            executing.addAll(readyToExecute);
            ArrayList<NodeOutput> outputs = new ArrayList<>();
            if (invokeModels) {
                for (String name : readyToExecute) {
                    // Resolve the input on this thread: the map is written here while workers
                    // run, and this node's entry is final once all of its parents completed.
                    RequestInput nodeInput = inputRequestMap.get(name);
                    executorCompletionService.submit(
                            () ->
                                    invokeModel(
                                            name,
                                            this.dag.getNodes().get(name).getWorkflowModel(),
                                            nodeInput,
                                            0));
                }
                outputs.add(takeCompletedNodeOutput(executorCompletionService));
            } else {
                for (String name : readyToExecute) {
                    outputs.add(new NodeOutput(name, null));
                }
            }
            for (NodeOutput output : outputs) {
                String nodeName = output.getNodeName();
                executing.remove(nodeName);
                zeroInDegree.remove(nodeName);
                if (!invokeModels) {
                    topoSortedList.add(nodeName);
                }
                Set<String> childNodes = this.dag.getDagMap().get(nodeName).get("outDegree");
                if (childNodes.isEmpty()) {
                    leafOutputs.add(output);
                    continue;
                }
                for (String childName : childNodes) {
                    if (invokeModels) {
                        // Must run before the decrement below (see helper's contract).
                        addParentOutputToChildInput(
                                childName, nodeName, (byte[]) output.getData(), input, inDegreeMap);
                    }
                    // One more parent finished; the child becomes ready when none remain.
                    inDegreeMap.replace(childName, inDegreeMap.get(childName) - 1);
                    if (inDegreeMap.get(childName) == 0) {
                        zeroInDegree.add(childName);
                    }
                }
            }
        }
        return leafOutputs;
    } finally {
        // A new pool is created per call; without this its threads would never terminate.
        // shutdown() (not shutdownNow()) lets any still-running invocations finish normally.
        if (executorService != null) {
            executorService.shutdown();
        }
    }
}

/**
 * Blocks until the next submitted node invocation completes and returns its output.
 *
 * @throws InternalServerException if the invocation failed or this thread was interrupted
 */
private NodeOutput takeCompletedNodeOutput(CompletionService<NodeOutput> completionService) {
    try {
        // take() blocks until a result is available; it never returns null.
        return completionService.take().get();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw toNodeExecutionException(e);
    } catch (ExecutionException e) {
        throw toNodeExecutionException(e);
    }
}

/**
 * Logs {@code e} and wraps the text after the last ':' of its message in an
 * InternalServerException. Falls back to {@code e.toString()} when the message is null, and to
 * the whole message when splitting leaves no segment.
 */
private InternalServerException toNodeExecutionException(Exception e) {
    String message = e.getMessage() != null ? e.getMessage() : e.toString();
    logger.error(message, e);
    String[] error = message.split(":");
    String clientMessage = error.length > 0 ? error[error.length - 1] : message;
    return new InternalServerException(clientMessage); // NOPMD
}

/**
 * Records a parent node's output as an input parameter of {@code childName} in {@code
 * inputRequestMap}.
 *
 * <p>If the map has no entry for the child yet, a new RequestInput is created with the headers
 * of {@code input}. Its single parameter is named {@code "body"} when the child's in-degree in
 * {@code inDegreeMap} is 1; otherwise it is named after the parent. If an entry already exists, a
 * parameter named after the parent is appended. Callers must invoke this before subtracting the
 * parent's edge from {@code inDegreeMap}.
 */
private void addParentOutputToChildInput(
        String childName,
        String parentName,
        byte[] response,
        RequestInput input,
        Map<String, Integer> inDegreeMap) {
    RequestInput childInput = this.inputRequestMap.get(childName);
    if (childInput == null) {
        List<InputParameter> params = new ArrayList<>();
        childInput = new RequestInput(UUID.randomUUID().toString());
        if (inDegreeMap.get(childName) == 1) {
            params.add(new InputParameter("body", response));
        } else {
            params.add(new InputParameter(parentName, response));
        }
        childInput.setParameters(params);
        childInput.setHeaders(input.getHeaders());
    } else {
        childInput.addParameter(new InputParameter(parentName, response));
    }
    this.inputRequestMap.put(childName, childInput);
}