in frontend/server/src/main/java/org/pytorch/serve/workflow/WorkflowManager.java [106:220]
/**
 * Registers a workflow archive and every model node of its DAG.
 *
 * <p>Model nodes are registered concurrently on a dedicated fixed pool of 4 threads. All
 * registration tasks are awaited, even after one of them fails. If any node fails, the nodes
 * that did register are rolled back via {@code removeArtifacts}, and an HTTP 500 status listing
 * every failure is returned. The workflow is added to {@code workflowMap} only when all nodes
 * registered successfully.
 *
 * @param workflowName name passed to {@code createWorkflowArchive} and, on rollback, to {@code
 *     removeArtifacts}
 * @param url location of the workflow archive; must not be null
 * @param responseTimeout forwarded to each model registration
 * @param synchronous forwarded to each model registration
 * @param s3SseKms NOTE(review): not used by this method (not forwarded to {@code
 *     createWorkflowArchive}); confirm whether that is intended
 * @return a status carrying the HTTP code, a message and, on failure, the causing exception
 * @throws BadRequestException if {@code url} is null
 * @throws ConflictStatusException if a workflow with the archive's name is already registered
 * @throws WorkflowException as declared for failures raised by the workflow helpers
 */
public StatusResponse registerWorkflow(
        String workflowName,
        String url,
        int responseTimeout,
        boolean synchronous,
        boolean s3SseKms)
        throws WorkflowException {
    if (url == null) {
        throw new BadRequestException("Parameter url is required.");
    }
    StatusResponse status = new StatusResponse();
    ExecutorService executorService = Executors.newFixedThreadPool(4);
    CompletionService<ModelRegistrationResult> executorCompletionService =
            new ExecutorCompletionService<>(executorService);
    List<String> failedMessages = new ArrayList<>();
    // Kept as ArrayList because it is handed to removeArtifacts for rollback.
    ArrayList<String> successNodes = new ArrayList<>();
    try {
        WorkflowArchive archive = createWorkflowArchive(workflowName, url);
        WorkFlow workflow = createWorkflow(archive);
        String registeredName = workflow.getWorkflowArchive().getWorkflowName();
        if (workflowMap.get(registeredName) != null) {
            throw new ConflictStatusException(
                    "Workflow " + registeredName + " is already registered.");
        }

        Map<String, Node> nodes = workflow.getDag().getNodes();
        int submitted = 0;
        for (Node node : nodes.values()) {
            WorkflowModel wfm = node.getWorkflowModel();
            executorCompletionService.submit(
                    () -> registerModelWrapper(wfm, responseTimeout, synchronous));
            submitted++;
        }

        // Drain every task, even after a failure. Otherwise tasks that are still running
        // could register models that are never recorded in successNodes or rolled back.
        for (int completed = 0; completed < submitted; completed++) {
            Future<ModelRegistrationResult> future = executorCompletionService.take();
            ModelRegistrationResult result;
            try {
                result = future.get();
            } catch (ExecutionException e) {
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                failedMessages.add("Failed to register a workflow model: " + cause);
                continue;
            }
            if (result.getResponse().getHttpResponseCode() != HttpURLConnection.HTTP_OK) {
                String msg = result.getResponse().getStatus();
                if (msg == null) {
                    msg =
                            "Failed to register the model "
                                    + result.getModelName()
                                    + ". Check error logs.";
                }
                failedMessages.add(msg);
            } else {
                successNodes.add(result.getModelName());
            }
        }

        if (!failedMessages.isEmpty()) {
            // Undo the nodes that did register so a retry does not hit leftover models.
            try {
                removeArtifacts(workflowName, workflow, successNodes);
            } catch (Exception e) {
                failedMessages.add(
                        "Error while doing rollback of failed workflow. Details: "
                                + e.getMessage());
            }
            String message =
                    String.format(
                            "Workflow %s has failed to register. Failures: %s",
                            registeredName, failedMessages.toString());
            status.setHttpResponseCode(HttpURLConnection.HTTP_INTERNAL_ERROR);
            status.setStatus(message);
            status.setE(new WorkflowException(message));
        } else {
            status.setHttpResponseCode(HttpURLConnection.HTTP_OK);
            status.setStatus(
                    String.format(
                            "Workflow %s has been registered and scaled successfully.",
                            registeredName));
            workflowMap.putIfAbsent(registeredName, workflow);
        }
    } catch (DownloadArchiveException e) {
        status.setHttpResponseCode(HttpURLConnection.HTTP_BAD_REQUEST);
        status.setStatus("Failed to download workflow archive file");
        status.setE(e);
    } catch (InvalidDAGException e) {
        status.setHttpResponseCode(HttpURLConnection.HTTP_BAD_REQUEST);
        status.setStatus("Invalid workflow specification");
        status.setE(e);
    } catch (InterruptedException e) {
        // Restore the interrupt status so code further up the stack can observe it.
        Thread.currentThread().interrupt();
        status.setHttpResponseCode(HttpURLConnection.HTTP_INTERNAL_ERROR);
        status.setStatus("Failed to register workflow.");
        status.setE(e);
    } catch (IOException e) {
        status.setHttpResponseCode(HttpURLConnection.HTTP_INTERNAL_ERROR);
        status.setStatus("Failed to register workflow.");
        status.setE(e);
    } finally {
        executorService.shutdown();
    }
    return status;
}