lib/maths/common/CPriorStateSerialiser.cc (199 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <maths/common/CPriorStateSerialiser.h>
#include <core/CLogger.h>
#include <core/CStatePersistInserter.h>
#include <core/CStateRestoreTraverser.h>
#include <maths/common/CConstantPrior.h>
#include <maths/common/CGammaRateConjugate.h>
#include <maths/common/CLogNormalMeanPrecConjugate.h>
#include <maths/common/CMultimodalPrior.h>
#include <maths/common/CMultinomialConjugate.h>
#include <maths/common/CMultivariateConstantPrior.h>
#include <maths/common/CMultivariateMultimodalPriorFactory.h>
#include <maths/common/CMultivariateNormalConjugateFactory.h>
#include <maths/common/CMultivariateOneOfNPriorFactory.h>
#include <maths/common/CMultivariatePrior.h>
#include <maths/common/CNormalMeanPrecConjugate.h>
#include <maths/common/COneOfNPrior.h>
#include <maths/common/CPoissonMeanConjugate.h>
#include <maths/common/CPrior.h>
#include <memory>
#include <string>
#include <typeinfo>
namespace ml {
namespace maths {
namespace common {
namespace {
// Persistence tags identifying which concrete sub-class of CPrior a state
// node holds: restore() compares node names with them and the CPrior persist
// operator selects one from the dynamic type of the prior.
// There needs to be one constant here per sub-class of CPrior.
// DO NOT change the existing tags if new sub-classes are added.
// NOTE(review): the two constructor arguments look like a short tag and a
// human readable tag - confirm against core::TPersistenceTag.
const core::TPersistenceTag GAMMA_TAG("a", "gamma");
const core::TPersistenceTag LOG_NORMAL_TAG("b", "log_normal");
const core::TPersistenceTag MULTIMODAL_TAG("c", "multimodal");
const core::TPersistenceTag NORMAL_TAG("d", "normal");
const core::TPersistenceTag ONE_OF_N_TAG("e", "one-of-n");
const core::TPersistenceTag POISSON_TAG("f", "poisson");
// NOTE(review): "multimonial" looks like a misspelling of "multinomial", but
// per the rule above existing tags must not be edited.
const core::TPersistenceTag MULTINOMIAL_TAG("g", "multimonial");
const core::TPersistenceTag CONSTANT_TAG("h", "constant");
// NOTE(review): not referenced by any code visible in this file.
const std::string EMPTY_STRING;
//! Constructs a prior of concrete type \p T from \p traverser and stores it
//! in the shared pointer \p ptr.
//!
//! This overload is for prior types whose restore constructor takes only
//! the traverser.
template<typename T>
void doRestore(std::shared_ptr<CPrior>& ptr, core::CStateRestoreTraverser& traverser) {
    ptr = std::shared_ptr<CPrior>(std::make_shared<T>(traverser));
}
//! Constructs a prior of concrete type \p T from \p traverser and stores it
//! in the unique pointer \p ptr.
//!
//! This overload is for prior types whose restore constructor takes only
//! the traverser.
template<typename T>
void doRestore(std::unique_ptr<CPrior>& ptr, core::CStateRestoreTraverser& traverser) {
    ptr = std::unique_ptr<CPrior>(std::make_unique<T>(traverser));
}
//! Constructs a prior of concrete type \p T from \p params and \p traverser
//! and stores it in the shared pointer \p ptr.
//!
//! This overload is for prior types whose restore constructor needs the
//! distribution restore parameters as well as the traverser.
template<typename T>
void doRestore(const SDistributionRestoreParams& params,
               std::shared_ptr<CPrior>& ptr,
               core::CStateRestoreTraverser& traverser) {
    ptr = std::shared_ptr<CPrior>(std::make_shared<T>(params, traverser));
}
//! Constructs a prior of concrete type \p T from \p params and \p traverser
//! and stores it in the unique pointer \p ptr.
//!
//! This overload is for prior types whose restore constructor needs the
//! distribution restore parameters as well as the traverser.
template<typename T>
void doRestore(const SDistributionRestoreParams& params,
               std::unique_ptr<CPrior>& ptr,
               core::CStateRestoreTraverser& traverser) {
    ptr = std::unique_ptr<CPrior>(std::make_unique<T>(params, traverser));
}
//! Restores a single prior into \p ptr from the nodes \p traverser visits.
//!
//! Every node name reached via traverser.next() is compared with the prior
//! tags; on a match the corresponding concrete prior is constructed via
//! doRestore and assigned to \p ptr. Unrecognised names are only warned about.
//!
//! \param[in] params Forwarded to every prior type except CConstantPrior.
//! \param[out] ptr Receives the restored prior; it is reset if the function
//! fails.
//! \param[in,out] traverser The state to read; advanced until next() fails.
//! \return True if exactly one recognised tag was seen and false otherwise.
template<typename PTR>
bool restore(const SDistributionRestoreParams& params,
             PTR& ptr,
             core::CStateRestoreTraverser& traverser) {
    std::size_t recognisedCount{0};

    do {
        const std::string& nodeName = traverser.name();
        bool recognised{true};

        if (nodeName == CONSTANT_TAG) {
            doRestore<CConstantPrior>(ptr, traverser);
        } else if (nodeName == GAMMA_TAG) {
            doRestore<CGammaRateConjugate>(params, ptr, traverser);
        } else if (nodeName == LOG_NORMAL_TAG) {
            doRestore<CLogNormalMeanPrecConjugate>(params, ptr, traverser);
        } else if (nodeName == MULTIMODAL_TAG) {
            doRestore<CMultimodalPrior>(params, ptr, traverser);
        } else if (nodeName == MULTINOMIAL_TAG) {
            doRestore<CMultinomialConjugate>(params, ptr, traverser);
        } else if (nodeName == NORMAL_TAG) {
            doRestore<CNormalMeanPrecConjugate>(params, ptr, traverser);
        } else if (nodeName == ONE_OF_N_TAG) {
            doRestore<COneOfNPrior>(params, ptr, traverser);
        } else if (nodeName == POISSON_TAG) {
            doRestore<CPoissonMeanConjugate>(params, ptr, traverser);
        } else {
            recognised = false;
            // Due to the way we divide large state into multiple chunks
            // this is not necessarily a problem - the unexpected element may be
            // marking the start of a new chunk
            LOG_WARN(<< "No prior distribution corresponds to node name "
                     << traverser.name());
        }

        if (recognised) {
            ++recognisedCount;
        }
    } while (traverser.next());

    if (recognisedCount == 1) {
        return true;
    }

    // Either nothing was restored or more than one prior was found; in both
    // cases the caller must not be left holding a prior.
    LOG_ERROR(<< "Expected 1 (got " << recognisedCount << ") prior model tags");
    ptr.reset();
    return false;
}
}
//! Restores a prior into the unique pointer \p ptr by delegating to the
//! file-local restore() template.
//!
//! \return The result of restore(): true if exactly one prior was restored.
bool CPriorStateSerialiser::operator()(const SDistributionRestoreParams& params,
                                       TPriorUPtr& ptr,
                                       core::CStateRestoreTraverser& traverser) const {
    return restore<TPriorUPtr>(params, ptr, traverser);
}
//! Restores a prior into the shared pointer \p ptr by delegating to the
//! file-local restore() template.
//!
//! \return The result of restore(): true if exactly one prior was restored.
bool CPriorStateSerialiser::operator()(const SDistributionRestoreParams& params,
                                       TPriorSPtr& ptr,
                                       core::CStateRestoreTraverser& traverser) const {
    return restore<TPriorSPtr>(params, ptr, traverser);
}
//! Persists \p prior as a nested level of \p inserter.
//!
//! The level is named with the tag that corresponds to the dynamic type of
//! \p prior and its content is written by CPrior::acceptPersistInserter. If
//! the dynamic type matches none of the known sub-classes an error is logged
//! and nothing is written.
void CPriorStateSerialiser::operator()(const CPrior& prior,
                                       core::CStatePersistInserter& inserter) const {
    // Find the tag for the concrete type. The checks are made in a fixed
    // order and the first successful cast wins.
    const core::TPersistenceTag* tag{nullptr};
    if (dynamic_cast<const CConstantPrior*>(&prior) != nullptr) {
        tag = &CONSTANT_TAG;
    } else if (dynamic_cast<const CGammaRateConjugate*>(&prior) != nullptr) {
        tag = &GAMMA_TAG;
    } else if (dynamic_cast<const CLogNormalMeanPrecConjugate*>(&prior) != nullptr) {
        tag = &LOG_NORMAL_TAG;
    } else if (dynamic_cast<const CMultimodalPrior*>(&prior) != nullptr) {
        tag = &MULTIMODAL_TAG;
    } else if (dynamic_cast<const CMultinomialConjugate*>(&prior) != nullptr) {
        tag = &MULTINOMIAL_TAG;
    } else if (dynamic_cast<const CNormalMeanPrecConjugate*>(&prior) != nullptr) {
        tag = &NORMAL_TAG;
    } else if (dynamic_cast<const COneOfNPrior*>(&prior) != nullptr) {
        tag = &ONE_OF_N_TAG;
    } else if (dynamic_cast<const CPoissonMeanConjugate*>(&prior) != nullptr) {
        tag = &POISSON_TAG;
    }

    if (tag == nullptr) {
        LOG_ERROR(<< "Prior distribution with type '" << typeid(prior).name()
                  << "' has no defined field name");
        return;
    }

    inserter.insertLevel(*tag, [&prior](core::CStatePersistInserter& level) {
        prior.acceptPersistInserter(level);
    });
}
//! Restores a single multivariate prior into \p ptr from the nodes
//! \p traverser visits.
//!
//! A node name is expected to be one of the CMultivariatePrior tags followed
//! by the dimension of the prior. The dimension is parsed from the text after
//! the tag and is passed, together with the traverser, to the constructor or
//! factory for the matching prior type. Unrecognised names are only warned
//! about.
//!
//! \param[in] params Forwarded to the prior factories.
//! \param[out] ptr Receives the restored prior. It is reset if the number of
//! recognised tags is not one; if a dimension cannot be parsed the function
//! returns immediately and leaves \p ptr as it was at that point.
//! \param[in,out] traverser The state to read; advanced until next() fails.
//! \return True if exactly one recognised tag was seen and every dimension
//! parsed, false otherwise.
bool CPriorStateSerialiser::operator()(const SDistributionRestoreParams& params,
                                       TMultivariatePriorPtr& ptr,
                                       core::CStateRestoreTraverser& traverser) const {
    std::size_t numResults = 0;
    do {
        const std::string& name = traverser.name();
        // The constant prior must be matched in the same way as the other
        // types. Requiring the name to equal the tag would leave nothing
        // after the tag to parse the dimension from, so this branch could
        // never succeed.
        if (name.find(CMultivariatePrior::CONSTANT_TAG) != std::string::npos) {
            std::size_t dimension{0};
            if (core::CStringUtils::stringToType(
                    name.substr(CMultivariatePrior::CONSTANT_TAG.length()), dimension) == false) {
                LOG_ERROR(<< "Bad dimension encoded in " << name);
                return false;
            }
            ptr.reset(new CMultivariateConstantPrior(dimension, traverser));
            ++numResults;
        } else if (name.find(CMultivariatePrior::MULTIMODAL_TAG) != std::string::npos) {
            std::size_t dimension{0};
            if (core::CStringUtils::stringToType(
                    name.substr(CMultivariatePrior::MULTIMODAL_TAG.length()),
                    dimension) == false) {
                LOG_ERROR(<< "Bad dimension encoded in " << name);
                return false;
            }
            // NOTE(review): any status returned by the factory restore calls
            // is not inspected here - confirm whether it should be.
            CMultivariateMultimodalPriorFactory::restore(dimension, params, ptr, traverser);
            ++numResults;
        } else if (name.find(CMultivariatePrior::NORMAL_TAG) != std::string::npos) {
            std::size_t dimension{0};
            if (core::CStringUtils::stringToType(
                    name.substr(CMultivariatePrior::NORMAL_TAG.length()), dimension) == false) {
                LOG_ERROR(<< "Bad dimension encoded in " << name);
                return false;
            }
            CMultivariateNormalConjugateFactory::restore(dimension, params, ptr, traverser);
            ++numResults;
        } else if (name.find(CMultivariatePrior::ONE_OF_N_TAG) != std::string::npos) {
            std::size_t dimension{0};
            if (core::CStringUtils::stringToType(
                    name.substr(CMultivariatePrior::ONE_OF_N_TAG.length()), dimension) == false) {
                LOG_ERROR(<< "Bad dimension encoded in " << name);
                return false;
            }
            CMultivariateOneOfNPriorFactory::restore(dimension, params, ptr, traverser);
            ++numResults;
        } else {
            // Due to the way we divide large state into multiple chunks
            // this is not necessarily a problem - the unexpected element may be
            // marking the start of a new chunk
            LOG_WARN(<< "No prior distribution corresponds to node name "
                     << traverser.name());
        }
    } while (traverser.next());

    if (numResults != 1) {
        LOG_ERROR(<< "Expected 1 (got " << numResults << ") prior model tags");
        ptr.reset();
        return false;
    }
    return true;
}
//! Persists \p prior as a nested level of \p inserter.
//!
//! The level is named with prior.persistenceTag() and its content is written
//! by CMultivariatePrior::acceptPersistInserter.
void CPriorStateSerialiser::operator()(const CMultivariatePrior& prior,
                                       core::CStatePersistInserter& inserter) const {
    inserter.insertLevel(prior.persistenceTag(),
                         [&prior](core::CStatePersistInserter& level) {
                             prior.acceptPersistInserter(level);
                         });
}
}
}
}