include/maths/common/CXMeansOnlineFactory.h (69 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#ifndef INCLUDED_ml_maths_common_CXMeansOnlineFactory_h
#define INCLUDED_ml_maths_common_CXMeansOnlineFactory_h
#include <maths/common/CClusterer.h>
#include <maths/common/CLinearAlgebraFwd.h>
#include <maths/common/ImportExport.h>
#include <maths/common/MathsTypes.h>
#include <boost/static_assert.hpp>
#include <cstddef>
namespace ml {
namespace core {
class CStateRestoreTraverser;
}
namespace maths {
namespace common {
template<typename POINT>
class CClusterer;
struct SDistributionRestoreParams;
namespace xmeans_online_factory_detail {
template<typename T, std::size_t N>
class CFactory {};
#define XMEANS_FACTORY(T, N) \
template<> \
class MATHS_COMMON_EXPORT CFactory<T, N> { \
public: \
static CClusterer<CVectorNx1<T, N>>* make(maths_t::EDataType dataType, \
maths_t::EClusterWeightCalc weightCalc, \
double decayRate, \
double minimumClusterFraction, \
double minimumClusterCount, \
double minimumCategoryCount); \
static CClusterer<CVectorNx1<T, N>>* \
restore(const SDistributionRestoreParams& params, \
const CClustererTypes::TSplitFunc& splitFunc, \
const CClustererTypes::TMergeFunc& mergeFunc, \
core::CStateRestoreTraverser& traverser); \
}
XMEANS_FACTORY(CFloatStorage, 2);
XMEANS_FACTORY(CFloatStorage, 3);
XMEANS_FACTORY(CFloatStorage, 4);
XMEANS_FACTORY(CFloatStorage, 5);
#undef XMEANS_FACTORY
}
//! \brief Factory for multivariate x-means online clusterers.
class MATHS_COMMON_EXPORT CXMeansOnlineFactory {
public:
//! Create a new x-means clusterer.
//!
//! \param[in] dataType The type of data which will be clustered.
//! \param[in] weightCalc The style of the cluster weight calculation
//! (see maths_t::EClusterWeightCalc for details).
//! \param[in] decayRate Controls the rate at which information is
//! lost from the clusters.
//! \param[in] minimumClusterFraction The minimum fractional count
//! of points in a cluster.
//! \param[in] minimumClusterCount The minimum count of points in a
//! cluster.
template<typename T, std::size_t N>
static inline CClusterer<CVectorNx1<T, N>>* make(maths_t::EDataType dataType,
maths_t::EClusterWeightCalc weightCalc,
double decayRate,
double minimumClusterFraction,
double minimumClusterCount,
double minimumCategoryCount) {
return xmeans_online_factory_detail::CFactory<T, N>::make(
dataType, weightCalc, decayRate, minimumClusterFraction,
minimumClusterCount, minimumCategoryCount);
}
//! Construct by traversing a state document.
template<typename T, std::size_t N>
static inline CClusterer<CVectorNx1<T, N>>*
restore(const SDistributionRestoreParams& params,
const CClustererTypes::TSplitFunc& splitFunc,
const CClustererTypes::TMergeFunc& mergeFunc,
core::CStateRestoreTraverser& traverser) {
return xmeans_online_factory_detail::CFactory<T, N>::restore(
params, splitFunc, mergeFunc, traverser);
}
};
}
}
}
#endif // INCLUDED_ml_maths_common_CXMeansOnlineFactory_h