in src/main/java/org/opensearch/ad/ml/EntityColdStarter.java [213:316]
/**
 * Try to cold start (train) the model of one entity from its historical data.
 *
 * <p>Cold start is skipped, and {@code listener} completed with {@code null}, when:
 * <ul>
 *   <li>an overload was recorded less than {@code coolDownMinutes} ago (see the failure handler below);</li>
 *   <li>the per-detector door keeper reports that {@code modelId} may have been attempted recently.</li>
 * </ul>
 * Otherwise {@code modelId} is recorded in the door keeper. The entity's training data is then fetched
 * on the AD thread pool, and the result is handled by a callback that also runs on that pool.
 *
 * <p>The callback handles three outcomes:
 * <ul>
 *   <li>If enough samples are accumulated in {@code modelState}, the model is trained.</li>
 *   <li>If data exists but samples are insufficient, the model state is queued for a checkpoint write.</li>
 *   <li>In both cases, and when no data is returned at all, {@code listener} receives {@code null}.</li>
 * </ul>
 * Errors while fetching or training are reported through {@code listener.onFailure}.
 *
 * @param modelId    model id of the entity; also the door keeper key
 * @param entity     entity whose historical data is fetched and whose model is trained
 * @param detectorId id of the detector the entity belongs to
 * @param modelState model state that receives the fetched samples and, if trained, the model
 * @param detector   detector config; its detection interval and shingle size are read here
 * @param listener   notified once the cold start attempt finishes or is skipped
 */
private void coldStart(
    String modelId,
    Entity entity,
    String detectorId,
    ModelState<EntityModel> modelState,
    AnomalyDetector detector,
    ActionListener<Void> listener
) {
    logger.debug("Trigger cold start for {}", modelId);
    // Back off from cold starting for coolDownMinutes after the last observed overload.
    // The same injected clock is used when recording lastThrottledColdStartTime, so the window is consistent.
    if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) {
        listener.onResponse(null);
        return;
    }
    // Remains true unless the asynchronous fetch has been handed to the thread pool. On every other exit
    // (door keeper hit, or an exception before/while submitting), the finally block completes the listener.
    boolean earlyExit = true;
    try {
        DoorKeeper doorKeeper = doorKeepers
            .computeIfAbsent(
                detectorId,
                id -> {
                    // Reset every DOOR_KEEPER_MAINTENANCE_FREQ detection intervals.
                    // The original comment says 60 intervals; TODO confirm against AnomalyDetectorSettings.
                    return new DoorKeeper(
                        AnomalyDetectorSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION,
                        AnomalyDetectorSettings.DOOR_KEEPER_FAULSE_POSITIVE_RATE,
                        detector.getDetectionIntervalDuration().multipliedBy(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
                        clock
                    );
                }
            );
        // Won't retry cold start for an entity until the door keeper is reset.
        // The door keeper is built with a false-positive rate, so this check may occasionally skip an
        // entity that was never attempted.
        if (doorKeeper.mightContain(modelId)) {
            return;
        }
        doorKeeper.put(modelId);
        ActionListener<Optional<List<double[][]>>> coldStartCallBack = ActionListener.wrap(trainingData -> {
            try {
                if (trainingData.isPresent()) {
                    List<double[][]> dataPoints = trainingData.get();
                    combineTrainSamples(dataPoints, modelId, modelState);
                    Queue<double[]> samples = modelState.getModel().getSamples();
                    // only train models if we have enough samples
                    if (samples.size() >= numMinSamples) {
                        // trainModelFromDataSegments saves the trained model itself. It is called from
                        // multiple places, so saving is kept implicit there rather than repeated at call sites.
                        trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize());
                        logger.info("Succeeded in training entity: {}", modelId);
                    } else {
                        // save to checkpoint
                        checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM);
                        logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size());
                    }
                } else {
                    logger.info("Cannot get training data for {}", modelId);
                }
                listener.onResponse(null);
            } catch (Exception e) {
                listener.onFailure(e);
            }
        }, exception -> {
            try {
                logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception);
                Throwable cause = Throwables.getRootCause(exception);
                if (ExceptionUtil.isOverloaded(cause)) {
                    logger.error("too many requests");
                    // Use the injected clock (not Instant.now()) so the cool-down check above,
                    // which also reads clock.instant(), compares values from the same time source.
                    lastThrottledColdStartTime = clock.instant();
                } else if (cause instanceof AnomalyDetectionException || exception instanceof AnomalyDetectionException) {
                    // e.g., cannot find anomaly detector
                    nodeStateManager.setException(detectorId, exception);
                } else {
                    nodeStateManager.setException(detectorId, new AnomalyDetectionException(detectorId, cause));
                }
                listener.onFailure(exception);
            } catch (Exception e) {
                listener.onFailure(e);
            }
        });
        // Fetch the data off the calling thread. The ThreadedActionListener runs the callback on the AD pool too.
        threadPool
            .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)
            .execute(
                () -> getEntityColdStartData(
                    detectorId,
                    entity,
                    new ThreadedActionListener<>(
                        logger,
                        threadPool,
                        AnomalyDetectorPlugin.AD_THREAD_POOL_NAME,
                        coldStartCallBack,
                        false
                    )
                )
            );
        earlyExit = false;
    } finally {
        if (earlyExit) {
            listener.onResponse(null);
        }
    }
}