public ArrayList execute()

in frontend/server/src/main/java/org/pytorch/serve/ensemble/DagExecutor.java [41:145]


    /**
     * Walks the DAG from its start nodes, releasing a node once all of its parents have been
     * processed, and returns the outputs of the leaf nodes (nodes whose "outDegree" set is empty).
     *
     * <p>The method has two modes, selected by {@code topoSortedList}:
     *
     * <ul>
     *   <li>{@code null}: every node is run through {@code invokeModel} on a 4-thread pool that
     *       lives only for this call. Each finished node's output is attached as an input
     *       parameter to its children's requests in {@code inputRequestMap}.
     *   <li>non-null: no model is invoked. Node names are appended to {@code topoSortedList} in
     *       the order they are processed, and each output is a {@code NodeOutput} with {@code
     *       null} data.
     * </ul>
     *
     * <p>NOTE(review): the in-degree map and the start-node set are obtained from the dag and
     * mutated here. If the dag hands out its internal collections rather than copies, a second
     * call on the same instance would observe already-consumed state; confirm against Dag.
     *
     * @param input the original request; its headers (and, for start nodes, its parameters) are
     *     copied into the per-node requests when models are executed
     * @param topoSortedList {@code null} to execute the models, or a list that receives the node
     *     names in processing order
     * @return the outputs of the leaf nodes
     * @throws InternalServerException if waiting for a node result is interrupted or a node's
     *     task fails
     */
    public ArrayList<NodeOutput> execute(RequestInput input, ArrayList<String> topoSortedList) {

        final boolean runModels = topoSortedList == null;

        ExecutorService executorService = null;
        CompletionService<NodeOutput> executorCompletionService = null;
        if (runModels) {
            executorService = Executors.newFixedThreadPool(4);
            executorCompletionService = new ExecutorCompletionService<>(executorService);
        }

        try {
            Map<String, Integer> inDegreeMap = this.dag.getInDegreeMap();
            Set<String> zeroInDegree = dag.getStartNodeNames();
            Set<String> executing = new HashSet<>();

            if (runModels) {
                // Start nodes receive a copy of the caller's headers and parameters.
                for (String s : zeroInDegree) {
                    RequestInput newInput = new RequestInput(UUID.randomUUID().toString());
                    newInput.setHeaders(input.getHeaders());
                    newInput.setParameters(input.getParameters());
                    inputRequestMap.put(s, newInput);
                }
            }

            ArrayList<NodeOutput> leafOutputs = new ArrayList<>();

            while (!zeroInDegree.isEmpty()) {
                // Nodes stay in zeroInDegree until their output is consumed, so filter out the
                // ones that were already submitted in an earlier iteration.
                Set<String> readyToExecute = new HashSet<>(zeroInDegree);
                readyToExecute.removeAll(executing);
                executing.addAll(readyToExecute);

                ArrayList<NodeOutput> outputs = new ArrayList<>();
                if (runModels) {
                    for (String name : readyToExecute) {
                        executorCompletionService.submit(
                                () ->
                                        invokeModel(
                                                name,
                                                this.dag.getNodes().get(name).getWorkflowModel(),
                                                inputRequestMap.get(name),
                                                0));
                    }

                    try {
                        // take() blocks until a task has completed and never returns null.
                        // Exactly one result is consumed per iteration so that newly unblocked
                        // children are submitted as early as possible.
                        outputs.add(executorCompletionService.take().get());
                    } catch (InterruptedException | ExecutionException e) {
                        if (e instanceof InterruptedException) {
                            // Restore the flag so callers further up can observe the interrupt.
                            Thread.currentThread().interrupt();
                        }
                        // InterruptedException usually carries no message; fall back to
                        // toString() instead of failing with a NullPointerException.
                        String message = e.getMessage() != null ? e.getMessage() : e.toString();
                        logger.error(message, e);
                        // Only the text after the last ':' is reported to the caller.
                        String[] error = message.split(":");
                        throw new InternalServerException(
                                error.length > 0 ? error[error.length - 1] : message); // NOPMD
                    }
                } else {
                    // Ordering-only mode: pretend every ready node finished immediately.
                    for (String name : readyToExecute) {
                        outputs.add(new NodeOutput(name, null));
                    }
                }

                for (NodeOutput output : outputs) {
                    String nodeName = output.getNodeName();
                    executing.remove(nodeName);
                    zeroInDegree.remove(nodeName);

                    if (!runModels) {
                        topoSortedList.add(nodeName);
                    }

                    Set<String> childNodes = this.dag.getDagMap().get(nodeName).get("outDegree");
                    if (childNodes.isEmpty()) {
                        leafOutputs.add(output);
                    } else {
                        for (String newNodeName : childNodes) {

                            if (runModels) {
                                byte[] response = (byte[]) output.getData();

                                RequestInput newInput = this.inputRequestMap.get(newNodeName);
                                if (newInput == null) {
                                    // First parent to finish creates the child's request. A
                                    // single-parent child gets the payload as "body"; otherwise
                                    // the parameter is named after the producing node.
                                    List<InputParameter> params = new ArrayList<>();
                                    newInput = new RequestInput(UUID.randomUUID().toString());
                                    if (inDegreeMap.get(newNodeName) == 1) {
                                        params.add(new InputParameter("body", response));
                                    } else {
                                        params.add(new InputParameter(nodeName, response));
                                    }
                                    newInput.setParameters(params);
                                    newInput.setHeaders(input.getHeaders());
                                } else {
                                    newInput.addParameter(new InputParameter(nodeName, response));
                                }
                                this.inputRequestMap.put(newNodeName, newInput);
                            }

                            // One fewer unprocessed parent; release the child once none remain.
                            int remaining = inDegreeMap.get(newNodeName) - 1;
                            inDegreeMap.replace(newNodeName, remaining);
                            if (remaining == 0) {
                                zeroInDegree.add(newNodeName);
                            }
                        }
                    }
                }
            }

            return leafOutputs;
        } finally {
            if (executorService != null) {
                // The pool is created per call; without this its worker threads would never
                // terminate. shutdown() lets any task still in flight (only possible on the
                // failure path) finish before the threads exit.
                executorService.shutdown();
            }
        }
    }