lib/model/CCountingModelFactory.cc (134 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <model/CCountingModelFactory.h>
#include <maths/common/CConstantPrior.h>
#include <maths/common/CMultivariateConstantPrior.h>
#include <model/CCountingModel.h>
#include <model/CDataGatherer.h>
#include <model/CSearchKey.h>
#include <memory>
namespace ml {
namespace model {
CCountingModelFactory::CCountingModelFactory(const SModelParams& params,
const TInterimBucketCorrectorWPtr& interimBucketCorrector,
model_t::ESummaryMode summaryMode,
const std::string& summaryCountFieldName)
: CModelFactory(params, interimBucketCorrector), m_DetectorIndex(),
m_SummaryMode(summaryMode),
m_SummaryCountFieldName(summaryCountFieldName), m_UseNull(false) {
}
// Return a heap-allocated copy of this factory, copy-constructed from
// *this. The result is a raw pointer created with new.
CCountingModelFactory* CCountingModelFactory::clone() const {
    const CCountingModelFactory& prototype = *this;
    return new CCountingModelFactory(prototype);
}
// Create a new counting model for the data gatherer held in initData.
//
// Returns nullptr (after logging an error) if initData carries no data
// gatherer; otherwise returns a CCountingModel allocated with new, built
// from this factory's model parameters and interim bucket corrector.
CAnomalyDetectorModel*
CCountingModelFactory::makeModel(const SModelInitializationData& initData) const {
    TDataGathererPtr gatherer{initData.s_DataGatherer};
    if (gatherer == nullptr) {
        LOG_ERROR(<< "NULL data gatherer");
        return nullptr;
    }

    return new CCountingModel(this->modelParams(), gatherer,
                              this->interimBucketCorrector());
}
// Create a counting model for the data gatherer held in initData, passing
// the state traverser through to the CCountingModel constructor.
//
// Returns nullptr (after logging an error) if initData carries no data
// gatherer; otherwise returns a CCountingModel allocated with new.
CAnomalyDetectorModel*
CCountingModelFactory::makeModel(const SModelInitializationData& initData,
                                 core::CStateRestoreTraverser& traverser) const {
    TDataGathererPtr gatherer{initData.s_DataGatherer};
    if (gatherer == nullptr) {
        LOG_ERROR(<< "NULL data gatherer");
        return nullptr;
    }

    return new CCountingModel(this->modelParams(), gatherer,
                              this->interimBucketCorrector(), traverser);
}
// Create a new event rate data gatherer for the partition field value and
// start time supplied in initData, using this factory's summary mode, model
// parameters, search key and features.
CModelFactory::TDataGathererPtr
CCountingModelFactory::makeDataGatherer(const SGathererInitializationData& initData) const {
    // Brace-initialised positionally: the entries must stay in this order.
    // Kept as a named local so it outlives the make_shared call below.
    const CBucketGatherer::SBucketGathererInitData gathererInit{
        m_SummaryCountFieldName, m_PersonFieldName, EMPTY_STRING, EMPTY_STRING, {},
        initData.s_StartTime,    0};

    const auto& params = this->modelParams();
    const CSearchKey& key = this->searchKey();

    return std::make_shared<CDataGatherer>(model_t::E_EventRate, m_SummaryMode,
                                           params, initData.s_PartitionFieldValue,
                                           key, m_Features, gathererInit);
}
// Create an event rate data gatherer for the given partition field value,
// passing the state traverser through to the CDataGatherer constructor.
// Unlike the overload taking SGathererInitializationData, no feature list is
// passed and the start time entry of the init data is zero.
CModelFactory::TDataGathererPtr
CCountingModelFactory::makeDataGatherer(const std::string& partitionFieldValue,
                                        core::CStateRestoreTraverser& traverser) const {
    // Brace-initialised positionally: the entries must stay in this order.
    // Kept as a named local so it outlives the make_shared call below.
    CBucketGatherer::SBucketGathererInitData gathererInit{
        m_SummaryCountFieldName,
        m_PersonFieldName,
        EMPTY_STRING,
        EMPTY_STRING,
        {},
        0,
        0};

    const auto& params = this->modelParams();
    const CSearchKey& key = this->searchKey();

    return std::make_shared<CDataGatherer>(model_t::E_EventRate, m_SummaryMode, params,
                                           partitionFieldValue, key, gathererInit, traverser);
}
// Return a newly allocated constant prior. Neither the feature nor the model
// parameters influence the result.
CCountingModelFactory::TPriorPtr
CCountingModelFactory::defaultPrior(model_t::EFeature /*feature*/,
                                    const SModelParams& /*params*/) const {
    using TConstantPrior = maths::common::CConstantPrior;
    return std::make_unique<TConstantPrior>();
}
// Return a newly allocated multivariate constant prior whose dimension is
// model_t::dimension(feature). The model parameters are ignored.
CCountingModelFactory::TMultivariatePriorUPtr
CCountingModelFactory::defaultMultivariatePrior(model_t::EFeature feature,
                                                const SModelParams& /*params*/) const {
    using TConstantPrior = maths::common::CMultivariateConstantPrior;
    const auto featureDimension = model_t::dimension(feature);
    return std::make_unique<TConstantPrior>(featureDimension);
}
// Return a newly allocated multivariate constant prior of dimension two.
// Neither the feature nor the model parameters influence the result.
CCountingModelFactory::TMultivariatePriorUPtr
CCountingModelFactory::defaultCorrelatePrior(model_t::EFeature /*feature*/,
                                             const SModelParams& /*params*/) const {
    using TConstantPrior = maths::common::CMultivariateConstantPrior;
    constexpr int CORRELATE_DIMENSION{2};
    return std::make_unique<TConstantPrior>(CORRELATE_DIMENSION);
}
// Return the search key for this factory's current configuration.
//
// The key is built on first use from the detector index, the function
// derived from the feature list, the "use null" flag, the exclude frequent
// setting and the person and partition field names, then cached. The
// setters on this class clear the cache, so the next call rebuilds the key.
const CSearchKey& CCountingModelFactory::searchKey() const {
    if (!m_SearchKeyCache.has_value()) {
        const auto function = function_t::function(m_Features);
        const auto excludeFrequent = this->modelParams().s_ExcludeFrequent;
        m_SearchKeyCache.emplace(m_DetectorIndex, function, m_UseNull, excludeFrequent,
                                 "", m_PersonFieldName, "", m_PartitionFieldName);
    }
    return *m_SearchKeyCache;
}
// Report whether this factory is configured for simple counting, as decided
// by CSearchKey::isSimpleCount from the function derived from the feature
// list and the person field name.
bool CCountingModelFactory::isSimpleCount() const {
    const auto function = function_t::function(m_Features);
    return CSearchKey::isSimpleCount(function, m_PersonFieldName);
}
// Return the summary mode this factory was constructed with.
model_t::ESummaryMode CCountingModelFactory::summaryMode() const {
    return this->m_SummaryMode;
}
// Return the data type reported by this factory: always integer data.
maths_t::EDataType CCountingModelFactory::dataType() const {
    constexpr maths_t::EDataType COUNT_DATA_TYPE{maths_t::E_IntegerData};
    return COUNT_DATA_TYPE;
}
// Set the detector index. The cached search key depends on it, so the cache
// is cleared and searchKey() rebuilds the key on its next call.
void CCountingModelFactory::detectorIndex(int detectorIndex) {
    m_SearchKeyCache = std::nullopt;
    m_DetectorIndex = detectorIndex;
}
// Set the field names used by this factory.
//
// Only the partition and by field names are stored (the by field becomes
// the person field); the over field, value field and influence field names
// are ignored. The cached search key depends on both stored names, so the
// cache is cleared.
void CCountingModelFactory::fieldNames(const std::string& partitionFieldName,
                                       const std::string& /*overFieldName*/,
                                       const std::string& byFieldName,
                                       const std::string& /*valueFieldName*/,
                                       const TStrVec& /*influenceFieldNames*/) {
    m_SearchKeyCache = std::nullopt;
    m_PersonFieldName = byFieldName;
    m_PartitionFieldName = partitionFieldName;
}
// Set the "use null" flag. The cached search key depends on it, so the
// cache is cleared and searchKey() rebuilds the key on its next call.
void CCountingModelFactory::useNull(bool useNull) {
    m_SearchKeyCache = std::nullopt;
    m_UseNull = useNull;
}
// Replace the feature list. The cached search key is derived from the
// features, so the cache is cleared and searchKey() rebuilds the key on its
// next call.
void CCountingModelFactory::features(const TFeatureVec& features) {
    m_SearchKeyCache = std::nullopt;
    m_Features = features;
}
// Return references to the non-empty partitioning field names, in the
// order: partition field, then person field. The result holds references to
// this factory's members, so zero, one or two entries are possible.
CCountingModelFactory::TStrCRefVec CCountingModelFactory::partitioningFields() const {
    TStrCRefVec fields;
    fields.reserve(2);

    // Append a field name only if it has been set.
    const auto appendIfSet = [&fields](const std::string& name) {
        if (name.empty() == false) {
            fields.emplace_back(name);
        }
    };
    appendIfSet(m_PartitionFieldName);
    appendIfSet(m_PersonFieldName);

    return fields;
}
// Return a placeholder minimum seasonal variance scale. The value is not
// used for this factory; zero is returned only so that the function has
// a result.
double CCountingModelFactory::minimumSeasonalVarianceScale() const {
    constexpr double UNUSED_SCALE{0.0};
    return UNUSED_SCALE;
}
}
}