public void createMLTaskAndTrain()

in plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java [97:136]


    /**
     * Builds an {@code MLTask} describing a training request and starts the training.
     *
     * <p>Async request: the task is first created through {@code mlTaskManager.createMLTask}. Once
     * that succeeds, {@code listener} is answered immediately with an {@code MLTrainingOutput} that
     * carries the new task id and the task's current state name (the model id is {@code null} at
     * that point). Training is then started with an internal listener that forwards the outcome
     * to {@code handleMLTaskComplete} or {@code handleMLTaskFailure}. If task creation fails, the
     * exception is passed to {@code listener.onFailure}.
     *
     * <p>Sync request: the task is given a random UUID as its id and {@code listener} is handed
     * directly to {@code startTrainingTask}.
     *
     * @param request training request supplying the ML input and the async flag
     * @param listener caller's listener; see above for when it is notified in each mode
     */
    public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
        MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
        // Use one timestamp so createTime and lastUpdateTime start out identical.
        Instant now = Instant.now();
        MLTask mlTask = MLTask
            .builder()
            .taskType(MLTaskType.TRAINING)
            .inputType(inputDataType)
            .functionName(request.getMlInput().getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNode(clusterService.localNode().getId())
            .createTime(now)
            .lastUpdateTime(now)
            .async(request.isAsync())
            .build();

        if (!request.isAsync()) {
            // Sync path: the task is not created via mlTaskManager, so assign a locally generated id.
            mlTask.setTaskId(UUID.randomUUID().toString());
            startTrainingTask(mlTask, request.getMlInput(), listener);
            return;
        }

        // Async path. mlTask.isAsync() mirrors request.isAsync() (set in the builder above),
        // so no further async check is needed inside the callback.
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller right away with the task id; the model id is not known yet.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception to the logger so the cause and stack trace are recorded.
                log.error("Failed to train ML model for task " + taskId, ex);
                handleMLTaskFailure(mlTask, ex);
            });
            // NOTE(review): if startTrainingTask can throw synchronously, the enclosing
            // ActionListener.wrap would presumably route that to listener.onFailure after
            // listener.onResponse has already fired — confirm startTrainingTask reports
            // errors only through the listener it is given.
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, e -> listener.onFailure(e)));
    }