private void predict()

in plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java [138:207]


    /**
     * Runs a prediction for {@code request} using the model document stored in {@code ML_MODEL_INDEX}.
     *
     * <p>The executing-task counter is incremented and {@code mlTask} is added to the task manager up
     * front. After that, every path ends in either {@code handleMLTaskComplete} (success) or
     * {@code handlePredictFailure}. The failure cases are: invalid model id, model not found,
     * permission denied, model fetch failure, malformed model document, and prediction error.
     * NOTE(review): this assumes {@code handlePredictFailure} undoes the increment/add done above, as
     * the original invalid-id and permission branches already relied on. Confirm against its
     * implementation.
     *
     * @param mlTask         task tracking this prediction
     * @param inputDataFrame data set the prediction is run on
     * @param request        prediction request; its model id selects the model document and is used
     *                       for the permission check
     * @param listener       receives the {@link MLTaskResponse} on success, or the failure otherwise
     */
    private void predict(
        MLTask mlTask,
        DataFrame inputDataFrame,
        MLPredictionTaskRequest request,
        ActionListener<MLTaskResponse> listener
    ) {
        // track ML task count and add ML task into cache
        mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
        mlTaskManager.add(mlTask);

        String modelId = request.getModelId();
        if (modelId == null) {
            IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
            log.error("ModelId is invalid", e);
            handlePredictFailure(mlTask, listener, e);
            return;
        }

        MLInput mlInput = request.getMlInput();
        // Fetch by the same id that was validated above and is used for the permission check below.
        GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, modelId);
        client.get(getRequest, ActionListener.wrap(r -> {
            if (r == null || !r.isExists()) {
                // Route through handlePredictFailure so the task bookkeeping is cleaned up too,
                // not just the listener notified.
                handlePredictFailure(mlTask, listener, new ResourceNotFoundException("No model found, please check the modelId."));
                return;
            }
            Map<String, Object> source = r.getSourceAsMap();
            User requestUser = getUserContext(client);
            User resourceUser = User.parse((String) source.get(USER));
            if (!checkUserPermissions(requestUser, resourceUser, modelId)) {
                // The backend roles of request user and resource user doesn't have intersection
                OpenSearchException e = new OpenSearchException(
                    "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + modelId
                );
                log.debug(e);
                handlePredictFailure(mlTask, listener, e);
                return;
            }

            MLOutput output;
            try {
                // Decode inside the try so a malformed document (missing content, wrong field
                // types, bad Base64) fails the task through the normal cleanup path.
                Model model = new Model();
                model.setName((String) source.get(MLModel.MODEL_NAME));
                model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
                model.setContent(Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT)));

                // run predict
                mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
                output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
                if (output instanceof MLPredictionOutput) {
                    MLPredictionOutput predictionOutput = (MLPredictionOutput) output;
                    predictionOutput.setTaskId(mlTask.getTaskId());
                    predictionOutput.setStatus(mlTaskManager.get(mlTask.getTaskId()).getState().name());
                }

                // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
                handleMLTaskComplete(mlTask);
            } catch (Exception e) {
                // todo need to specify what exception
                log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + modelId, e);
                handlePredictFailure(mlTask, listener, e);
                return;
            }

            // NOTE(review): ActionListener.wrap forwards exceptions thrown from this callback to the
            // failure handler below. If listener.onResponse throws, handlePredictFailure would then
            // run after handleMLTaskComplete. Confirm whether that case needs guarding.
            listener.onResponse(MLTaskResponse.builder().output(output).build());
        }, e -> {
            log.error("Failed to predict model " + modelId, e);
            handlePredictFailure(mlTask, listener, e);
        }));
    }