in frontend/server/src/main/java/org/pytorch/serve/ensemble/WorkFlow.java [32:155]
/**
 * Builds a workflow from a workflow archive.
 *
 * <p>The spec file and handler file are located inside the archive's workflow directory using
 * the names from the archive manifest. The parsed spec is stored in {@code workflowSpec}. Its
 * {@code models} section is read in two passes: first the workflow-wide defaults ({@code
 * min-workers}, {@code max-workers}, {@code batch-size}, {@code max-batch-delay}, {@code
 * retry-attempts}, {@code timeout-ms}), then every other key, each of which is treated as a
 * model definition and registered under the name {@code <workflowName>__<key>}. Reading the
 * defaults first makes the result independent of the order of keys in the spec.
 *
 * <p>The {@code dag} section maps a node name to the list of node names it points to. Every
 * node and edge is added to {@code dag}, which is validated at the end.
 *
 * @param workflowArchive archive providing the workflow directory, manifest and workflow name
 * @throws IOException declared by the original signature; presumably raised while reading the
 *     spec file (confirm against {@code readSpecFile})
 * @throws InvalidDAGException declared by the original signature; presumably raised by the DAG
 *     operations / {@code dag.validate()} (confirm against the DAG class)
 * @throws InvalidWorkflowException declared by the original signature; not thrown directly by
 *     this constructor
 */
public WorkFlow(WorkflowArchive workflowArchive)
        throws IOException, InvalidDAGException, InvalidWorkflowException {
    this.workflowArchive = workflowArchive;
    File specFile =
            new File(
                    this.workflowArchive.getWorkflowDir(),
                    this.workflowArchive.getManifest().getWorkflow().getSpecFile());
    File handlerFile =
            new File(
                    this.workflowArchive.getWorkflowDir(),
                    this.workflowArchive.getManifest().getWorkflow().getHandler());
    String workFlowName = this.workflowArchive.getWorkflowName();
    Map<String, WorkflowModel> models = new HashMap<String, WorkflowModel>();

    @SuppressWarnings("unchecked")
    LinkedHashMap<String, Object> spec =
            (LinkedHashMap<String, Object>) this.readSpecFile(specFile);
    this.workflowSpec = spec;

    // NOTE(review): a spec without a "models" (or, below, "dag") section makes the
    // corresponding map null and fails with a NullPointerException on entrySet().
    @SuppressWarnings("unchecked")
    Map<String, Object> modelsInfo = (Map<String, Object>) this.workflowSpec.get("models");

    // Pass 1: consume the workflow-wide defaults and set the model definitions aside, so a
    // default that appears after a model in the spec still applies to that model.
    Map<String, Object> modelEntries = new LinkedHashMap<String, Object>();
    for (Map.Entry<String, Object> entry : modelsInfo.entrySet()) {
        if (!applyWorkflowDefault(entry.getKey(), entry.getValue())) {
            modelEntries.put(entry.getKey(), entry.getValue());
        }
    }

    // Pass 2: build one WorkflowModel per remaining key; per-model settings override the
    // workflow-wide defaults collected above.
    for (Map.Entry<String, Object> entry : modelEntries.entrySet()) {
        // Any key that is not a default is assumed to map to a model-info mapping.
        @SuppressWarnings("unchecked")
        Map<String, Object> model = (Map<String, Object>) entry.getValue();
        String modelName = workFlowName + "__" + entry.getKey();
        WorkflowModel wfm =
                new WorkflowModel(
                        modelName,
                        (String) model.get("url"),
                        (int) model.getOrDefault("min-workers", minWorkers),
                        (int) model.getOrDefault("max-workers", maxWorkers),
                        (int) model.getOrDefault("batch-size", batchSize),
                        (int) model.getOrDefault("max-batch-delay", maxBatchDelay),
                        (int) model.getOrDefault("retry-attempts", retryAttempts),
                        (int) model.getOrDefault("timeout-ms", timeOutMs),
                        null);
        models.put(modelName, wfm);
    }

    @SuppressWarnings("unchecked")
    Map<String, Object> dagInfo = (Map<String, Object>) this.workflowSpec.get("dag");
    for (Map.Entry<String, Object> entry : dagInfo.entrySet()) {
        String nodeName = entry.getKey();
        Node fromNode =
                new Node(
                        nodeName,
                        resolveNodeModel(models, workFlowName, nodeName, handlerFile));
        dag.addNode(fromNode);

        @SuppressWarnings("unchecked")
        ArrayList<String> values = (ArrayList<String>) entry.getValue();
        if (values == null) {
            // A node listed with no value has no outgoing edges.
            continue;
        }
        for (String toNodeName : values) {
            // Blank or missing target names are ignored rather than turned into nodes.
            if (toNodeName == null || ("").equals(toNodeName.strip())) {
                continue;
            }
            Node toNode =
                    new Node(
                            toNodeName,
                            resolveNodeModel(models, workFlowName, toNodeName, handlerFile));
            dag.addNode(toNode);
            dag.addEdge(fromNode, toNode);
        }
    }
    dag.validate();
}

/**
 * Stores {@code value} in the matching workflow-wide default field when {@code key} is one of
 * the recognised default keys of the {@code models} section.
 *
 * @param key key read from the {@code models} section of the spec
 * @param value value associated with {@code key}; cast to {@code int} for default keys
 * @return {@code true} if {@code key} was a default key and was consumed, {@code false} if it
 *     should be treated as a model definition
 */
private boolean applyWorkflowDefault(String key, Object value) {
    switch (key) {
        case "min-workers":
            minWorkers = (int) value;
            return true;
        case "max-workers":
            maxWorkers = (int) value;
            return true;
        case "batch-size":
            batchSize = (int) value;
            return true;
        case "max-batch-delay":
            maxBatchDelay = (int) value;
            return true;
        case "retry-attempts":
            retryAttempts = (int) value;
            return true;
        case "timeout-ms":
            timeOutMs = (int) value;
            return true;
        default:
            return false;
    }
}

/**
 * Returns the WorkflowModel backing a DAG node: the declared model named {@code
 * <workFlowName>__<nodeName>} if present in {@code models}, otherwise a new WorkflowModel with
 * a null URL, worker/batch settings of 1, 1, 1, 0, the workflow-wide retry and timeout
 * values, and the handler string {@code <handler file path>:<nodeName>}.
 *
 * <p>Fallback models are not added to {@code models}; each call creates a fresh instance.
 */
private WorkflowModel resolveNodeModel(
        Map<String, WorkflowModel> models,
        String workFlowName,
        String nodeName,
        File handlerFile) {
    String modelName = workFlowName + "__" + nodeName;
    if (models.containsKey(modelName)) {
        return models.get(modelName);
    }
    return new WorkflowModel(
            modelName,
            null,
            1,
            1,
            1,
            0,
            retryAttempts,
            timeOutMs,
            handlerFile.getPath() + ":" + nodeName);
}