lib/maths/common/CMultivariateNormalConjugateFactory.cc (61 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <maths/common/CMultivariateNormalConjugateFactory.h>

#include <maths/common/CMultivariateNormalConjugate.h>

namespace ml {
namespace maths {
namespace common {
namespace {

//! \brief Creates CMultivariateNormalConjugate<N> objects for a dimension
//! N which is fixed at compile time.
//!
//! Both overloads return a raw pointer which the caller is expected to
//! hand straight to an owning smart pointer.
template<std::size_t N>
class CFactory {
public:
    //! Construct a prior from the state held by \p traverser.
    static CMultivariateNormalConjugate<N>*
    make(const SDistributionRestoreParams& params, core::CStateRestoreTraverser& traverser) {
        return new CMultivariateNormalConjugate<N>(params, traverser);
    }

    //! Create a copy of the non-informative prior for \p dataType and
    //! \p decayRate.
    static CMultivariateNormalConjugate<N>* make(maths_t::EDataType dataType, double decayRate) {
        return CMultivariateNormalConjugate<N>::nonInformativePrior(dataType, decayRate)
            .clone();
    }
};

//! Map the run time \p dimension onto the matching CFactory<N>::make
//! overload and store the result in \p ptr.
//!
//! Only dimensions 2, 3, 4 and 5 are supported. For any other value an
//! error is logged and \p ptr is left exactly as it was passed in.
//!
//! \param[in] dimension The dimension of the prior to create.
//! \param[out] ptr Takes ownership of the new prior via reset().
//! \param[in] args The arguments passed on, unchanged, to CFactory<N>::make;
//! they select which overload is called.
template<typename PTR, typename... ARGS>
void createPrior(std::size_t dimension, PTR& ptr, ARGS&... args) {
    switch (dimension) {
    case 2:
        ptr.reset(CFactory<2>::make(args...));
        break;
    case 3:
        ptr.reset(CFactory<3>::make(args...));
        break;
    case 4:
        ptr.reset(CFactory<4>::make(args...));
        break;
    case 5:
        ptr.reset(CFactory<5>::make(args...));
        break;
    default:
        LOG_ERROR(<< "Unsupported dimension " << dimension);
        break;
    }
}
}

//! Create a non-informative prior of the requested \p dimension.
//!
//! \return The new prior, or a default constructed TPriorPtr if
//! \p dimension is not supported (an error is logged in that case).
CMultivariateNormalConjugateFactory::TPriorPtr
CMultivariateNormalConjugateFactory::nonInformative(std::size_t dimension,
                                                    maths_t::EDataType dataType,
                                                    double decayRate) {
    TPriorPtr ptr;
    createPrior(dimension, ptr, dataType, decayRate);
    return ptr;
}

//! Create a prior of the requested \p dimension from the state held by
//! \p traverser.
//!
//! \p ptr is always cleared first, so any object it previously held is
//! released even when \p dimension is not supported.
//!
//! \return True if \p ptr is non-null afterwards and false otherwise.
bool CMultivariateNormalConjugateFactory::restore(std::size_t dimension,
                                                  const SDistributionRestoreParams& params,
                                                  TPriorPtr& ptr,
                                                  core::CStateRestoreTraverser& traverser) {
    ptr.reset();
    createPrior(dimension, ptr, params, traverser);
    return ptr != nullptr;
}
}
}
}