lib/model/CCountingModelFactory.cc (134 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <model/CCountingModelFactory.h>

#include <maths/common/CConstantPrior.h>
#include <maths/common/CMultivariateConstantPrior.h>

#include <model/CCountingModel.h>
#include <model/CDataGatherer.h>
#include <model/CSearchKey.h>

#include <memory>

namespace ml {
namespace model {

//! Set up the factory with its model parameters, interim bucket corrector,
//! summary mode and summary count field name. The detector index is
//! value-initialised and the "use null" flag starts out false.
CCountingModelFactory::CCountingModelFactory(const SModelParams& params,
                                             const TInterimBucketCorrectorWPtr& interimBucketCorrector,
                                             model_t::ESummaryMode summaryMode,
                                             const std::string& summaryCountFieldName)
    : CModelFactory(params, interimBucketCorrector),
      m_DetectorIndex(),
      m_SummaryMode(summaryMode),
      m_SummaryCountFieldName(summaryCountFieldName),
      m_UseNull(false) {
}

//! Create a copy of this factory with new.
//! NOTE(review): the raw pointer is presumably owned by the caller - confirm
//! against the base class contract.
CCountingModelFactory* CCountingModelFactory::clone() const {
    return new CCountingModelFactory(*this);
}

//! Create a new counting model on top of the data gatherer carried by
//! \p initData.
//!
//! \return The new model, or nullptr (after logging an error) when
//! \p initData carries no data gatherer.
CAnomalyDetectorModel*
CCountingModelFactory::makeModel(const SModelInitializationData& initData) const {
    TDataGathererPtr gatherer = initData.s_DataGatherer;
    if (gatherer == nullptr) {
        LOG_ERROR(<< "NULL data gatherer");
        return nullptr;
    }
    return new CCountingModel(this->modelParams(), gatherer,
                              this->interimBucketCorrector());
}

//! Create a counting model on top of the data gatherer carried by
//! \p initData, additionally handing \p traverser to the model constructor.
//!
//! \return The new model, or nullptr (after logging an error) when
//! \p initData carries no data gatherer.
CAnomalyDetectorModel*
CCountingModelFactory::makeModel(const SModelInitializationData& initData,
                                 core::CStateRestoreTraverser& traverser) const {
    TDataGathererPtr gatherer = initData.s_DataGatherer;
    if (gatherer == nullptr) {
        LOG_ERROR(<< "NULL data gatherer");
        return nullptr;
    }
    return new CCountingModel(this->modelParams(), gatherer,
                              this->interimBucketCorrector(), traverser);
}

//! Create a new event rate data gatherer for the partition field value in
//! \p initData, starting at the start time in \p initData.
CModelFactory::TDataGathererPtr
CCountingModelFactory::makeDataGatherer(const SGathererInitializationData& initData) const {
    // Only the summary count field, the person field and the start time are
    // supplied; every other entry is left empty or zero.
    const CBucketGatherer::SBucketGathererInitData gathererInit{
        m_SummaryCountFieldName,
        m_PersonFieldName,
        EMPTY_STRING,
        EMPTY_STRING,
        {},
        initData.s_StartTime,
        0};
    return std::make_shared<CDataGatherer>(
        model_t::E_EventRate, m_SummaryMode, this->modelParams(),
        initData.s_PartitionFieldValue, this->searchKey(), m_Features, gathererInit);
}

//! Create an event rate data gatherer for \p partitionFieldValue, handing
//! \p traverser to the gatherer constructor.
CModelFactory::TDataGathererPtr
CCountingModelFactory::makeDataGatherer(const std::string& partitionFieldValue,
                                        core::CStateRestoreTraverser& traverser) const {
    // No start time is available on this path, so both trailing entries are
    // zero.
    CBucketGatherer::SBucketGathererInitData gathererInit{
        m_SummaryCountFieldName,
        m_PersonFieldName,
        EMPTY_STRING,
        EMPTY_STRING,
        {},
        0,
        0};
    return std::make_shared<CDataGatherer>(
        model_t::E_EventRate, m_SummaryMode, this->modelParams(),
        partitionFieldValue, this->searchKey(), gathererInit, traverser);
}

//! The default prior is a constant prior whatever the feature and parameters.
CCountingModelFactory::TPriorPtr
CCountingModelFactory::defaultPrior(model_t::EFeature /*feature*/,
                                    const SModelParams& /*params*/) const {
    return std::make_unique<maths::common::CConstantPrior>();
}

//! The default multivariate prior is a constant prior constructed with the
//! dimension of \p feature.
CCountingModelFactory::TMultivariatePriorUPtr
CCountingModelFactory::defaultMultivariatePrior(model_t::EFeature feature,
                                                const SModelParams& /*params*/) const {
    return std::make_unique<maths::common::CMultivariateConstantPrior>(
        model_t::dimension(feature));
}

//! The default correlate prior is a multivariate constant prior constructed
//! with the value 2 (presumably its dimension).
CCountingModelFactory::TMultivariatePriorUPtr
CCountingModelFactory::defaultCorrelatePrior(model_t::EFeature /*feature*/,
                                             const SModelParams& /*params*/) const {
    return std::make_unique<maths::common::CMultivariateConstantPrior>(2);
}

//! Get the search key for this factory's configuration.
//!
//! The key is built on first use and cached; every setter which changes an
//! input to the key resets the cache.
const CSearchKey& CCountingModelFactory::searchKey() const {
    if (m_SearchKeyCache.has_value()) {
        return *m_SearchKeyCache;
    }
    m_SearchKeyCache.emplace(m_DetectorIndex, function_t::function(m_Features),
                             m_UseNull, this->modelParams().s_ExcludeFrequent,
                             "", m_PersonFieldName, "", m_PartitionFieldName);
    return *m_SearchKeyCache;
}

//! Check, via CSearchKey::isSimpleCount, whether the configured features and
//! person field correspond to simple count.
bool CCountingModelFactory::isSimpleCount() const {
    return CSearchKey::isSimpleCount(function_t::function(m_Features), m_PersonFieldName);
}

//! Get the summary mode supplied at construction.
model_t::ESummaryMode CCountingModelFactory::summaryMode() const {
    return m_SummaryMode;
}

//! This factory always reports integer data.
maths_t::EDataType CCountingModelFactory::dataType() const {
    return maths_t::E_IntegerData;
}

//! Set the detector index and drop the cached search key.
void CCountingModelFactory::detectorIndex(int detectorIndex) {
    m_DetectorIndex = detectorIndex;
    m_SearchKeyCache.reset();
}

//! Set the field names and drop the cached search key.
//!
//! Only the partition field name and the by field name (stored as the
//! person field name) are used; the other arguments are ignored.
void CCountingModelFactory::fieldNames(const std::string& partitionFieldName,
                                       const std::string& /*overFieldName*/,
                                       const std::string& byFieldName,
                                       const std::string& /*valueFieldName*/,
                                       const TStrVec& /*influenceFieldNames*/) {
    m_PartitionFieldName = partitionFieldName;
    m_PersonFieldName = byFieldName;
    m_SearchKeyCache.reset();
}

//! Set the "use null" flag and drop the cached search key.
void CCountingModelFactory::useNull(bool useNull) {
    m_UseNull = useNull;
    m_SearchKeyCache.reset();
}

//! Set the features and drop the cached search key.
void CCountingModelFactory::features(const TFeatureVec& features) {
    m_Features = features;
    m_SearchKeyCache.reset();
}

//! Get the non-empty partitioning field names: the partition field name
//! first, then the person field name.
CCountingModelFactory::TStrCRefVec CCountingModelFactory::partitioningFields() const {
    TStrCRefVec fields;
    fields.reserve(2);
    const auto appendIfSet = [&fields](const auto& fieldName) {
        if (!fieldName.empty()) {
            fields.emplace_back(fieldName);
        }
    };
    appendIfSet(m_PartitionFieldName);
    appendIfSet(m_PersonFieldName);
    return fields;
}

//! Not used by this factory; a placeholder value is returned.
double CCountingModelFactory::minimumSeasonalVarianceScale() const {
    // unused, return something
    return 0.0;
}

} // namespace model
} // namespace ml