in plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java [102:136]
/**
 * Starts a synchronous (non-async) prediction task for the given request.
 *
 * <p>A new {@link MLTask} is built with a random UUID, type {@code PREDICTION}, state {@code CREATED},
 * and the local node as its worker node. The input dataset is then turned into a {@link DataFrame}:
 * <ul>
 *   <li>For {@code SEARCH_QUERY} input, the search query is parsed asynchronously and the resulting
 *       data frame is handed to {@code predict} on the {@code TASK_THREAD_POOL}.</li>
 *   <li>For any other input type, the data frame is parsed on the calling thread and {@code predict}
 *       is dispatched to the {@code TASK_THREAD_POOL}.</li>
 * </ul>
 *
 * <p>In both branches, a failure while preparing the input or scheduling the work is logged. The
 * task is then registered with the task manager, marked as failed, and reported to
 * {@code listener.onFailure}; it is not thrown to the caller.
 *
 * @param request  the prediction request carrying the model id and ML input
 * @param listener receives the prediction response, or the failure if the task cannot be started
 */
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInput mlInput = request.getMlInput();
    MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskId(UUID.randomUUID().toString())
        .modelId(request.getModelId())
        .taskType(MLTaskType.PREDICTION)
        .inputType(inputDataType)
        .functionName(mlInput.getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(false)
        .build();

    if (inputDataType == MLInputDataType.SEARCH_QUERY) {
        ActionListener<DataFrame> dataFrameActionListener = ActionListener
            .wrap(dataFrame -> predict(mlTask, dataFrame, request, listener), e -> {
                log.error("Failed to generate DataFrame from search query", e);
                mlTaskManager.addIfAbsent(mlTask);
                handleMLTaskFailure(mlTask, e);
                listener.onFailure(e);
            });
        // Parse completion may arrive on a transport thread; fork back onto the task pool before predicting.
        mlInputDatasetHandler
            .parseSearchQueryInput(
                mlInput.getInputDataset(),
                new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
            );
    } else {
        try {
            DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset());
            threadPool.executor(TASK_THREAD_POOL).execute(() -> predict(mlTask, inputDataFrame, request, listener));
        } catch (Exception e) {
            // Parse errors or a rejected executor submission must reach the listener, mirroring the
            // SEARCH_QUERY branch, instead of escaping to the caller with the task never marked failed.
            log.error("Failed to start prediction task", e);
            mlTaskManager.addIfAbsent(mlTask);
            handleMLTaskFailure(mlTask, e);
            listener.onFailure(e);
        }
    }
}