public WorkFlow(WorkflowArchive workflowArchive)

in frontend/server/src/main/java/org/pytorch/serve/ensemble/WorkFlow.java [32:155]


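    /**
     * Builds the workflow from a {@link WorkflowArchive}: the spec file is read, workflow-level
     * defaults are taken from its "models" section, each model entry becomes a
     * {@link WorkflowModel}, and the "dag" section is assembled into nodes and edges before
     * the graph is validated.
     */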
    public WorkFlow(WorkflowArchive workflowArchive)
            throws IOException, InvalidDAGException, InvalidWorkflowException {
        this.workflowArchive = workflowArchive;
        File specFile =
                new File(
                        this.workflowArchive.getWorkflowDir(),
                        this.workflowArchive.getManifest().getWorkflow().getSpecFile());
        File handlerFile =
                new File(
                        this.workflowArchive.getWorkflowDir(),
                        this.workflowArchive.getManifest().getWorkflow().getHandler());

        String workFlowName = this.workflowArchive.getWorkflowName();
        Map<String, WorkflowModel> models = new HashMap<>();

        @SuppressWarnings("unchecked")
        LinkedHashMap<String, Object> spec =
                (LinkedHashMap<String, Object>) this.readSpecFile(specFile);
        this.workflowSpec = spec;

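        // The "models" section mixes workflow-level defaults (min-workers, max-workers,
        // batch-size, max-batch-delay, retry-attempts, timeout-ms) with per-model entries;
        // the switch below tells the two apart by key name.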
        @SuppressWarnings("unchecked")
        Map<String, Object> modelsInfo = (Map<String, Object>) this.workflowSpec.get("models");
        for (Map.Entry<String, Object> entry : modelsInfo.entrySet()) {
            String keyName = entry.getKey();

            switch (keyName) {
                case "min-workers":
                    minWorkers = (int) entry.getValue();
                    break;
                case "max-workers":
                    maxWorkers = (int) entry.getValue();
                    break;
                case "batch-size":
                    batchSize = (int) entry.getValue();
                    break;
                case "max-batch-delay":
                    maxBatchDelay = (int) entry.getValue();
                    break;
                case "retry-attempts":
                    retryAttempts = (int) entry.getValue();
                    break;
                case "timeout-ms":
                    timeOutMs = (int) entry.getValue();
                    break;
                default:
                    // Any other key is a per-model entry: a Map with that model's config
                    // (its url plus optional overrides of the workflow-level defaults above).
                    @SuppressWarnings("unchecked")
                    LinkedHashMap<String, Object> model =
                            (LinkedHashMap<String, Object>) entry.getValue();
                    String modelName = workFlowName + "__" + keyName;

                    WorkflowModel wfm =
                            new WorkflowModel(
                                    modelName,
                                    (String) model.get("url"),
                                    (int) model.getOrDefault("min-workers", minWorkers),
                                    (int) model.getOrDefault("max-workers", maxWorkers),
                                    (int) model.getOrDefault("batch-size", batchSize),
                                    (int) model.getOrDefault("max-batch-delay", maxBatchDelay),
                                    (int) model.getOrDefault("retry-attempts", retryAttempts),
                                    (int) model.getOrDefault("timeout-ms", timeOutMs),
                                    null);

                    models.put(modelName, wfm);
            }
        }

        @SuppressWarnings("unchecked")
        Map<String, Object> dagInfo = (Map<String, Object>) this.workflowSpec.get("dag");

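        // Each "dag" entry maps a node name to the list of its downstream node names.
        // Node names without a model entry are backed by a function in the handler file
        // ("<handler path>:<node name>") and get single-worker defaults.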
        for (Map.Entry<String, Object> entry : dagInfo.entrySet()) {
            String nodeName = entry.getKey();
            String modelName = workFlowName + "__" + nodeName;
            WorkflowModel wfm;
            if (!models.containsKey(modelName)) {
                wfm =
                        new WorkflowModel(
                                modelName,
                                null,
                                1,
                                1,
                                1,
                                0,
                                retryAttempts,
                                timeOutMs,
                                handlerFile.getPath() + ":" + nodeName);
            } else {
                wfm = models.get(modelName);
            }
            Node fromNode = new Node(nodeName, wfm);
            dag.addNode(fromNode);

            @SuppressWarnings("unchecked")
            ArrayList<String> values = (ArrayList<String>) entry.getValue();
            for (String toNodeName : values) {

                if (toNodeName == null || toNodeName.strip().isEmpty()) {
                    continue;
                }
                String toModelName = workFlowName + "__" + toNodeName;
                WorkflowModel toWfm;
                if (!models.containsKey(toModelName)) {
                    toWfm =
                            new WorkflowModel(
                                    toModelName,
                                    null,
                                    1,
                                    1,
                                    1,
                                    0,
                                    retryAttempts,
                                    timeOutMs,
                                    handlerFile.getPath() + ":" + toNodeName);
                } else {
                    toWfm = models.get(toModelName);
                }
                Node toNode = new Node(toNodeName, toWfm);
                dag.addNode(toNode);
                dag.addEdge(fromNode, toNode);
            }
        }
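        // Validation rejects malformed graphs (see the InvalidDAGException declared by this
        // constructor).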
        dag.validate();
    }
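
For reference, the spec file this constructor reads is a YAML document with the two top-level sections used above, "models" and "dag". The sketch below is illustrative only: the node names, the classifier.mar URL, and every value are made up, and only the key names come from the parsing code.

    models:
        min-workers: 1
        max-workers: 4
        batch-size: 8
        max-batch-delay: 50
        retry-attempts: 3
        timeout-ms: 3000
        classifier:
            url: classifier.mar
            batch-size: 4
    dag:
        pre_processing: [classifier]
        classifier: [post_processing]

In this sketch, "classifier" resolves to the WorkflowModel registered in the first loop, while "pre_processing" and "post_processing" have no model entries and therefore fall back to functions in the handler file, with the single-worker defaults used in the dag loop.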