in plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java [138:207]
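/**
 * Run prediction for the given ML task. The model referenced by the request is loaded from the
 * ML model index, the requesting user's backend roles are checked against the model owner's, and
 * the prediction result or failure is reported back through the action listener.
 */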
private void predict(
    MLTask mlTask,
    DataFrame inputDataFrame,
    MLPredictionTaskRequest request,
    ActionListener<MLTaskResponse> listener
) {
    // track ML task count and add ML task into cache
    mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
    mlTaskManager.add(mlTask);
    MLInput mlInput = request.getMlInput();
    // search model by model id.
    Model model = new Model();
    if (request.getModelId() != null) {
        GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
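        // fetch the model document asynchronously; prediction continues in the response listener below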
        client.get(getRequest, ActionListener.wrap(r -> {
            if (r == null || !r.isExists()) {
                listener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
                return;
            }
            Map<String, Object> source = r.getSourceAsMap();
            User requestUser = getUserContext(client);
            User resourceUser = User.parse((String) source.get(USER));
            if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
                // the backend roles of the request user and the resource user do not intersect
                OpenSearchException e = new OpenSearchException(
                    "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId()
                );
                log.debug(e);
                handlePredictFailure(mlTask, listener, e);
                return;
            }
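            // rebuild the in-memory model from the stored document fields; the model content is persisted base64-encoded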
            model.setName((String) source.get(MLModel.MODEL_NAME));
            model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
            byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
            model.setContent(decoded);
            // run predict
            MLOutput output;
            try {
                mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
                output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
                if (output instanceof MLPredictionOutput) {
                    ((MLPredictionOutput) output).setTaskId(mlTask.getTaskId());
                    ((MLPredictionOutput) output).setStatus(mlTaskManager.get(mlTask.getTaskId()).getState().name());
                }
                // Once prediction completes, reduce ML_EXECUTING_TASK_COUNT and update task state
                handleMLTaskComplete(mlTask);
            } catch (Exception e) {
                // TODO: narrow the exception type being caught
                log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + model.getName(), e);
                handlePredictFailure(mlTask, listener, e);
                return;
            }
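            // prediction succeeded; wrap the output in a task response and return it to the caller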
            MLTaskResponse response = MLTaskResponse.builder().output(output).build();
            listener.onResponse(response);
        }, e -> {
            log.error("Failed to predict model " + mlTask.getModelId(), e);
            listener.onFailure(e);
        }));
    } else {
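        // fail fast when the request does not carry a model id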
        IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
        log.error("ModelId is invalid", e);
        handlePredictFailure(mlTask, listener, e);
        return;
    }
}