in src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java [46:91]
protected void doExecute(Task task, TrainingModelRequest request,
                         ActionListener<TrainingModelResponse> listener) {
    // Entry point for a train-model request. It:
    //   1. describes the native memory the job needs (training data + model overhead),
    //   2. assembles a TrainingJob from the request,
    //   3. submits it to the TrainingJobRunner singleton, completing `listener` with a
    //      TrainingModelResponse built from the runner's response id, or with the failure.
    // TRAINING_REQUESTS is bumped once per call; failures delivered through the wrapper
    // listener below bump TRAINING_ERRORS.

    // Native-memory entry describing the training vectors drawn from the requested
    // index/field, loaded through the TrainingLoadStrategy.
    NativeMemoryEntryContext.TrainingDataEntryContext dataContext =
        new NativeMemoryEntryContext.TrainingDataEntryContext(
            request.getTrainingDataSizeInKB(),
            request.getTrainingIndex(),
            request.getTrainingField(),
            NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(),
            clusterService,
            request.getMaximumVectorCount(),
            request.getSearchSize()
        );

    // Anonymous allocation sized to the method's estimated overhead for this dimension,
    // i.e. the memory the model itself is expected to occupy while training.
    NativeMemoryEntryContext.AnonymousEntryContext modelOverheadContext =
        new NativeMemoryEntryContext.AnonymousEntryContext(
            request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()),
            NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance()
        );

    TrainingJob job = new TrainingJob(
        request.getModelId(),
        request.getKnnMethodContext(),
        NativeMemoryCacheManager.getInstance(),
        dataContext,
        modelOverheadContext,
        request.getDimension(),
        request.getDescription()
    );

    KNNCounter.TRAINING_REQUESTS.increment();

    // Pass successes straight through; count every failure before forwarding it.
    // NOTE(review): ActionListener.wrap routes an exception thrown by the success callback
    // into the failure callback, so a throwing listener.onResponse would also be counted
    // as a training error and reach listener.onFailure — confirm this is intended.
    ActionListener<TrainingModelResponse> errorCountingListener = ActionListener.wrap(
        listener::onResponse,
        failure -> {
            KNNCounter.TRAINING_ERRORS.increment();
            listener.onFailure(failure);
        }
    );

    try {
        TrainingJobRunner jobRunner = TrainingJobRunner.getInstance();
        jobRunner.execute(
            job,
            ActionListener.wrap(
                // The runner's success value exposes getId(); that id is what the caller gets back.
                runnerResponse -> errorCountingListener.onResponse(
                    new TrainingModelResponse(runnerResponse.getId())
                ),
                errorCountingListener::onFailure
            )
        );
    } catch (IOException ioException) {
        // Only the checked IOException is converted here; unchecked exceptions from
        // execute() propagate out of this method without touching TRAINING_ERRORS.
        errorCountingListener.onFailure(ioException);
    }
}