in plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java [97:136]
/**
 * Builds a training {@link MLTask} in {@code CREATED} state on the local node and starts training for it.
 *
 * <p>Async requests: the task is first persisted via {@code mlTaskManager.createMLTask}. Once it has an id,
 * {@code listener} is answered immediately with an {@link MLTrainingOutput} carrying the task id and the
 * task's current state (model id is {@code null} at this point). Training then proceeds. On success, the
 * trained model id is recorded on the task and {@code handleMLTaskComplete} is called. On failure,
 * {@code handleMLTaskFailure} is called. If persisting the task fails, {@code listener.onFailure} is called.
 *
 * <p>Sync requests: the task is not persisted here. It is given a random UUID as its id, and the training
 * result (or failure) is delivered directly to {@code listener}.
 *
 * @param request  training request providing the ML input and the async flag
 * @param listener receives the task-created response (async) or the training result (sync)
 */
public void createMLTaskAndTrain(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
    MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
    Instant now = Instant.now();
    MLTask mlTask = MLTask
        .builder()
        .taskType(MLTaskType.TRAINING)
        .inputType(inputDataType)
        .functionName(request.getMlInput().getFunctionName())
        .state(MLTaskState.CREATED)
        .workerNode(clusterService.localNode().getId())
        .createTime(now)
        .lastUpdateTime(now)
        .async(request.isAsync())
        .build();
    if (request.isAsync()) {
        mlTaskManager.createMLTask(mlTask, ActionListener.wrap(r -> {
            String taskId = r.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller right away with the task id; training continues afterwards and its
            // outcome is recorded on the task instead of being sent to the caller's listener.
            // (mlTask.isAsync() is always true here since it was built from request.isAsync().)
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(res -> {
                String modelId = ((MLTrainingOutput) res.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                handleMLTaskComplete(mlTask);
            }, ex -> {
                // Pass the exception so the cause/stack trace of the background failure is not lost.
                log.error("Failed to train ML model for task {}", taskId, ex);
                handleMLTaskFailure(mlTask, ex);
            });
            // NOTE(review): if startTrainingTask throws synchronously, ActionListener.wrap presumably routes
            // it to listener::onFailure even though listener.onResponse was already called above — confirm.
            startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, listener::onFailure));
    } else {
        // Sync tasks are not persisted, so generate a local id for tracking.
        mlTask.setTaskId(UUID.randomUUID().toString());
        startTrainingTask(mlTask, request.getMlInput(), listener);
    }
}