lib/model/unittest/ModelTestHelpers.h (136 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#ifndef INCLUDED_ml_model_ModelTestHelpers_h
#define INCLUDED_ml_model_ModelTestHelpers_h
#include <core/CJsonStatePersistInserter.h>
#include <core/CJsonStateRestoreTraverser.h>
#include <model/CDataGatherer.h>
#include <model/CSearchKey.h>
#include <model/ModelTypes.h>
#include <boost/test/unit_test.hpp>
#include <cstddef>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
namespace ml {
namespace model {
// Default constructed search key shared by the helpers in this header, for
// example when restoring a gatherer in testPersistence.
const CSearchKey KEY{};
// Empty string used wherever a field name or field value is not needed.
const std::string EMPTY_STRING{};
static void testPersistence(const SModelParams& params,
const CDataGatherer& origGatherer,
model_t::EAnalysisCategory category) {
// Test persistence. (We check for idempotency.)
std::ostringstream origJson;
core::CJsonStatePersistInserter::persist(
origJson, [&origGatherer](core::CJsonStatePersistInserter& inserter) {
origGatherer.acceptPersistInserter(inserter);
});
LOG_DEBUG(<< "gatherer JSON size " << origJson.str().size());
LOG_TRACE(<< "gatherer JSON representation:\n" << origJson.str());
// Restore the JSON into a new filter
// The traverser expects the state json in a embedded document
std::istringstream origJsonStrm{"{\"topLevel\" : " + origJson.str() + "}"};
core::CJsonStateRestoreTraverser traverser(origJsonStrm);
CBucketGatherer::SBucketGathererInitData bucketGathererInitData{
EMPTY_STRING, EMPTY_STRING, EMPTY_STRING, EMPTY_STRING, {}, 0, 0};
CDataGatherer restoredGatherer(category, model_t::E_None, params, EMPTY_STRING,
KEY, bucketGathererInitData, traverser);
BOOST_REQUIRE_EQUAL(origGatherer.checksum(), restoredGatherer.checksum());
// The JSON representation of the new filter should be the
// same as the original
std::ostringstream newJson;
core::CJsonStatePersistInserter::persist(
newJson, [&restoredGatherer](core::CJsonStatePersistInserter& inserter) {
restoredGatherer.acceptPersistInserter(inserter);
});
BOOST_REQUIRE_EQUAL(origJson.str(), newJson.str());
}
// Assert the basic properties of a gatherer which is expected to hold exactly
// one person, named "p", and no attributes, and whose current bucket start
// time and bucket length equal startTime and bucketLength respectively.
static void testGathererAttributes(const CDataGatherer& gatherer,
                                   core_t::TTime startTime,
                                   core_t::TTime bucketLength) {
    // People: a single active person called "p" with identifier 0.
    BOOST_REQUIRE_EQUAL(1, gatherer.numberActivePeople());
    BOOST_REQUIRE_EQUAL(1, gatherer.numberByFieldValues());
    BOOST_REQUIRE_EQUAL(std::string{"p"}, gatherer.personName(0));
    // Identifier 1 has not been assigned; "-" is presumably the fallback name
    // returned for an unknown identifier.
    BOOST_REQUIRE_EQUAL(std::string{"-"}, gatherer.personName(1));

    // Name to identifier lookup: the out-parameter is written by personId.
    std::size_t personIndex;
    BOOST_TEST_REQUIRE(gatherer.personId("p", personIndex));
    BOOST_REQUIRE_EQUAL(0, personIndex);
    BOOST_TEST_REQUIRE(!gatherer.personId("a.n.other p", personIndex));

    // Attributes: none are expected.
    BOOST_REQUIRE_EQUAL(0, gatherer.numberActiveAttributes());
    BOOST_REQUIRE_EQUAL(0, gatherer.numberOverFieldValues());

    // Bucketing configuration.
    BOOST_REQUIRE_EQUAL(startTime, gatherer.currentBucketStartTime());
    BOOST_REQUIRE_EQUAL(bucketLength, gatherer.bucketLength());
}
class CDataGathererBuilder {
public:
using TFeatureVec = CDataGatherer::TFeatureVec;
using TStrVec = CDataGatherer::TStrVec;
public:
CDataGathererBuilder(model_t::EAnalysisCategory gathererType,
const TFeatureVec& features,
const SModelParams& params,
const CSearchKey& searchKey,
const core_t::TTime startTime)
: m_Features(features), m_Params(params), m_StartTime(startTime),
m_SearchKey(searchKey), m_GathererType(gathererType) {}
CDataGatherer build() const {
CBucketGatherer::SBucketGathererInitData bucketGathererInitData{
m_SummaryCountFieldName,
m_PersonFieldName,
m_AttributeFieldName,
m_ValueFieldName,
m_InfluenceFieldNames,
m_StartTime,
static_cast<unsigned int>(m_SampleCountOverride)};
return {m_GathererType, m_SummaryMode, m_Params,
m_PartitionFieldValue, m_SearchKey, m_Features,
bucketGathererInitData};
}
std::shared_ptr<CDataGatherer> buildSharedPtr() const {
CBucketGatherer::SBucketGathererInitData bucketGathererInitData{
m_SummaryCountFieldName,
m_PersonFieldName,
m_AttributeFieldName,
m_ValueFieldName,
m_InfluenceFieldNames,
m_StartTime,
static_cast<unsigned int>(m_SampleCountOverride)};
return std::make_shared<CDataGatherer>(m_GathererType, m_SummaryMode, m_Params,
m_PartitionFieldValue, m_SearchKey,
m_Features, bucketGathererInitData);
}
CDataGathererBuilder& partitionFieldValue(std::string_view partitionFieldValue) {
m_PartitionFieldValue = partitionFieldValue;
return *this;
}
CDataGathererBuilder& personFieldName(std::string_view personFieldName) {
m_PersonFieldName = personFieldName;
return *this;
}
CDataGathererBuilder& valueFieldName(std::string_view valueFieldName) {
m_ValueFieldName = valueFieldName;
return *this;
}
CDataGathererBuilder& influenceFieldNames(const TStrVec& influenceFieldName) {
m_InfluenceFieldNames = influenceFieldName;
return *this;
}
CDataGathererBuilder& attributeFieldName(std::string_view attributeFieldName) {
m_AttributeFieldName = attributeFieldName;
return *this;
}
CDataGathererBuilder& gathererType(model_t::EAnalysisCategory gathererType) {
m_GathererType = gathererType;
return *this;
}
CDataGathererBuilder& sampleCountOverride(std::size_t sampleCount) {
m_SampleCountOverride = static_cast<int>(sampleCount);
return *this;
}
private:
const TFeatureVec& m_Features;
const SModelParams& m_Params;
core_t::TTime m_StartTime;
const CSearchKey& m_SearchKey;
model_t::EAnalysisCategory m_GathererType;
model_t::ESummaryMode m_SummaryMode{model_t::E_None};
std::string m_SummaryCountFieldName{EMPTY_STRING};
std::string m_PartitionFieldValue{EMPTY_STRING};
std::string m_PersonFieldName{EMPTY_STRING};
std::string m_AttributeFieldName{EMPTY_STRING};
std::string m_ValueFieldName{EMPTY_STRING};
TStrVec m_InfluenceFieldNames;
int m_SampleCountOverride{0};
};
}
}
#endif // INCLUDED_ml_model_ModelTestHelpers_h