protected void doExecute()

in src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java [46:91]


    /**
     * Builds a {@link TrainingJob} from the incoming request and submits it to the
     * {@link TrainingJobRunner}.
     *
     * <p>Two native-memory entry contexts are prepared first: one describing the training data
     * to load and one anonymous entry sized from the method context's overhead estimate.
     * {@code KNNCounter.TRAINING_REQUESTS} is incremented after the job has been constructed,
     * and {@code KNNCounter.TRAINING_ERRORS} is incremented every time a failure is reported
     * through the wrapped listener.
     *
     * @param task     the task associated with this execution (not read by this method)
     * @param request  training parameters: model id, method context, dimension, description,
     *                 and the training index/field with their sizing limits
     * @param listener notified with a {@link TrainingModelResponse} carrying the id of the
     *                 index response on success, or with the exception on failure
     */
    protected void doExecute(Task task, TrainingModelRequest request,
                             ActionListener<TrainingModelResponse> listener) {

        // Describes the training data that has to be loaded for this job.
        NativeMemoryEntryContext.TrainingDataEntryContext trainingDataContext =
            new NativeMemoryEntryContext.TrainingDataEntryContext(
                request.getTrainingDataSizeInKB(),
                request.getTrainingIndex(),
                request.getTrainingField(),
                NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(),
                clusterService,
                request.getMaximumVectorCount(),
                request.getSearchSize());

        // Anonymous allocation sized to the memory the model is expected to occupy during training.
        NativeMemoryEntryContext.AnonymousEntryContext modelOverheadContext =
            new NativeMemoryEntryContext.AnonymousEntryContext(
                request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()),
                NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance());

        TrainingJob job = new TrainingJob(
            request.getModelId(),
            request.getKnnMethodContext(),
            NativeMemoryCacheManager.getInstance(),
            trainingDataContext,
            modelOverheadContext,
            request.getDimension(),
            request.getDescription());

        KNNCounter.TRAINING_REQUESTS.increment();

        // Every failure routed through this listener is counted before being forwarded.
        ActionListener<TrainingModelResponse> countingListener = ActionListener.wrap(
            listener::onResponse,
            failure -> {
                KNNCounter.TRAINING_ERRORS.increment();
                listener.onFailure(failure);
            });

        try {
            TrainingJobRunner.getInstance().execute(
                job,
                ActionListener.wrap(
                    indexed -> {
                        // Translate the index response into the transport-level response.
                        TrainingModelResponse response = new TrainingModelResponse(indexed.getId());
                        countingListener.onResponse(response);
                    },
                    countingListener::onFailure));
        } catch (IOException ioFailure) {
            // A submission failure is reported the same way as an asynchronous one.
            countingListener.onFailure(ioFailure);
        }
    }