lib/maths/time_series/CTimeSeriesDecompositionStateSerialiser.cc (94 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <maths/time_series/CTimeSeriesDecompositionStateSerialiser.h>
#include <core/CLogger.h>
#include <core/CStatePersistInserter.h>
#include <core/CStateRestoreTraverser.h>
#include <maths/common/CRestoreParams.h>
#include <maths/time_series/CTimeSeriesDecomposition.h>
#include <maths/time_series/CTimeSeriesDecompositionStub.h>
#include <memory>
#include <string>
#include <typeinfo>
namespace ml {
namespace maths {
namespace time_series {
namespace {
// We use short field names to reduce the state size
// There needs to be one constant here per sub-class
// of CTimeSeriesDecompositionInterface.
// DO NOT change the existing tags if new sub-classes are added.
const core::TPersistenceTag TIME_SERIES_DECOMPOSITION_TAG("a", "time_series_decomposition");
const core::TPersistenceTag TIME_SERIES_DECOMPOSITION_STUB_TAG("b", "time_series_decomposition_stub");
const std::string EMPTY_STRING;
//! Implements restore for std::shared_ptr.
template<typename T>
void doRestore(std::shared_ptr<CTimeSeriesDecompositionInterface>& ptr) {
ptr = std::make_shared<T>();
}
//! Implements restore for std::unique_ptr.
template<typename T>
void doRestore(std::unique_ptr<CTimeSeriesDecompositionInterface>& ptr) {
ptr = std::make_unique<T>();
}
//! Implements restore for std::shared_ptr.
template<typename T>
void doRestore(const common::STimeSeriesDecompositionRestoreParams& params,
std::shared_ptr<CTimeSeriesDecompositionInterface>& ptr,
core::CStateRestoreTraverser& traverser) {
ptr = std::make_shared<T>(params, traverser);
}
//! Implements restore for std::unique_ptr.
template<typename T>
void doRestore(const common::STimeSeriesDecompositionRestoreParams& params,
std::unique_ptr<CTimeSeriesDecompositionInterface>& ptr,
core::CStateRestoreTraverser& traverser) {
ptr = std::make_unique<T>(params, traverser);
}
//! Implements restore into the supplied pointer.
template<typename PTR>
bool restore(const common::STimeSeriesDecompositionRestoreParams& params,
PTR& ptr,
core::CStateRestoreTraverser& traverser) {
std::size_t numResults{0};
do {
const std::string& name = traverser.name();
if (name == TIME_SERIES_DECOMPOSITION_TAG) {
doRestore<CTimeSeriesDecomposition>(params, ptr, traverser);
++numResults;
} else if (name == TIME_SERIES_DECOMPOSITION_STUB_TAG) {
doRestore<CTimeSeriesDecompositionStub>(ptr);
++numResults;
} else {
LOG_ERROR(<< "No decomposition corresponds to name " << traverser.name());
return false;
}
} while (traverser.next());
if (numResults != 1) {
LOG_ERROR(<< "Expected 1 (got " << numResults << ") decomposition tags");
ptr.reset();
return false;
}
return true;
}
}
bool CTimeSeriesDecompositionStateSerialiser::
operator()(const common::STimeSeriesDecompositionRestoreParams& params,
TDecompositionUPtr& ptr,
core::CStateRestoreTraverser& traverser) const {
return restore(params, ptr, traverser);
}
bool CTimeSeriesDecompositionStateSerialiser::
operator()(const common::STimeSeriesDecompositionRestoreParams& params,
TDecompositionSPtr& ptr,
core::CStateRestoreTraverser& traverser) const {
return restore(params, ptr, traverser);
}
void CTimeSeriesDecompositionStateSerialiser::
operator()(const CTimeSeriesDecompositionInterface& decomposition,
core::CStatePersistInserter& inserter) const {
if (dynamic_cast<const CTimeSeriesDecomposition*>(&decomposition) != nullptr) {
inserter.insertLevel(
TIME_SERIES_DECOMPOSITION_TAG,
std::bind(&CTimeSeriesDecomposition::acceptPersistInserter,
dynamic_cast<const CTimeSeriesDecomposition*>(&decomposition),
std::placeholders::_1));
} else if (dynamic_cast<const CTimeSeriesDecompositionStub*>(&decomposition) != nullptr) {
inserter.insertValue(TIME_SERIES_DECOMPOSITION_STUB_TAG, "");
} else {
LOG_ERROR(<< "Decomposition with type '" << typeid(decomposition).name()
<< "' has no defined name");
}
}
}
}
}