include/maths/common/COrderings.h (245 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#ifndef INCLUDED_ml_maths_common_COrderings_h
#define INCLUDED_ml_maths_common_COrderings_h
#include <core/CNonInstantiatable.h>
#include <core/UnwrapRef.h>
#include <maths/common/ImportExport.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
namespace ml {
namespace maths {
namespace common {
//! \brief A collection of useful functionality to order collections
//! of objects.
//!
//! DESCRIPTION:\n
//! This implements some generic commonly occurring ordering functionality.
//! In particular,
//! -# Assorted comparison objects which handle dereferencing optional and pointer
//! types, and unwrapping reference wrapped types.
//! -# Lexicographical comparison for small collections of objects with distinct
//! types (std::lexicographical_compare requires each range to hold a single type).
//! -# Efficiently, O(N log(N)), simultaneously sorting equal length collections
//! supporting random access iteration, using one of the collections to define
//! the order for all.
class COrderings : private core::CNonInstantiatable {
public:
    //! \brief Orders two pointers or std::optional values such that non-null
    //! values are less than null values and otherwise compares the referenced
    //! values, after unwrapping them with core::unwrap_ref, using std::less<>.
    //!
    //! Two null values are equivalent, i.e. neither is less than the other.
    struct MATHS_COMMON_EXPORT SDerefLess {
        template<typename T>
        inline bool operator()(const T* lhs, const T* rhs) const {
            return less(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::optional<T>& lhs, const std::optional<T>& rhs) const {
            return less(lhs, rhs);
        }
        template<typename T>
        static inline bool less(const T* lhs, const T* rhs) {
            bool lInitialized{lhs != nullptr};
            bool rInitialized{rhs != nullptr};
            // If either is null then lhs < rhs only when lhs is non-null and rhs is null.
            return lInitialized && rInitialized
                       ? s_Less(core::unwrap_ref(*lhs), core::unwrap_ref(*rhs))
                       : rInitialized < lInitialized;
        }
        template<typename T>
        static inline bool less(const std::optional<T>& lhs, const std::optional<T>& rhs) {
            bool lInitialized{lhs != std::nullopt};
            bool rInitialized{rhs != std::nullopt};
            // If either is empty then lhs < rhs only when lhs is set and rhs is empty.
            return lInitialized && rInitialized
                       ? s_Less(core::unwrap_ref(*lhs), core::unwrap_ref(*rhs))
                       : rInitialized < lInitialized;
        }
        static const std::less<> s_Less;
    };

    //! \brief Orders two pointers or std::optional values such that null
    //! values are greater than non-null values and otherwise compares the
    //! referenced values, after unwrapping them with core::unwrap_ref, using
    //! std::greater<>.
    //!
    //! Two null values are equivalent, i.e. neither is greater than the other.
    struct MATHS_COMMON_EXPORT SDerefGreater {
        template<typename T>
        inline bool operator()(const T* lhs, const T* rhs) const {
            return greater(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::optional<T>& lhs, const std::optional<T>& rhs) const {
            return greater(lhs, rhs);
        }
        template<typename T>
        static inline bool greater(const T* lhs, const T* rhs) {
            bool lInitialized{lhs != nullptr};
            bool rInitialized{rhs != nullptr};
            // If either is null then lhs > rhs only when lhs is null and rhs is non-null.
            return lInitialized && rInitialized
                       ? s_Greater(core::unwrap_ref(*lhs), core::unwrap_ref(*rhs))
                       : rInitialized > lInitialized;
        }
        template<typename T>
        static inline bool
        greater(const std::optional<T>& lhs, const std::optional<T>& rhs) {
            bool lInitialized{lhs != std::nullopt};
            bool rInitialized{rhs != std::nullopt};
            // If either is empty then lhs > rhs only when lhs is empty and rhs is set.
            return lInitialized && rInitialized
                       ? s_Greater(core::unwrap_ref(*lhs), core::unwrap_ref(*rhs))
                       : rInitialized > lInitialized;
        }
        static const std::greater<> s_Greater;
    };

    //! \brief Orders two objects, either of which may be reference wrapped,
    //! by comparing their values, unwrapped with core::unwrap_ref, using
    //! std::less<>.
    struct MATHS_COMMON_EXPORT SReferenceLess {
        template<typename U, typename V>
        inline bool operator()(const U& lhs, const V& rhs) const {
            return less(lhs, rhs);
        }
        template<typename U, typename V>
        static inline bool less(const U& lhs, const V& rhs) {
            return s_Less(core::unwrap_ref(lhs), core::unwrap_ref(rhs));
        }
        static const std::less<> s_Less;
    };

    //! \brief Orders two objects, either of which may be reference wrapped,
    //! by comparing their values, unwrapped with core::unwrap_ref, using
    //! std::greater<>.
    struct MATHS_COMMON_EXPORT SReferenceGreater {
        template<typename U, typename V>
        inline bool operator()(const U& lhs, const V& rhs) const {
            return greater(lhs, rhs);
        }
        template<typename U, typename V>
        static inline bool greater(const U& lhs, const V& rhs) {
            return s_Greater(core::unwrap_ref(lhs), core::unwrap_ref(rhs));
        }
        static const std::greater<> s_Greater;
    };

    //! \brief Wrapper around various less than comparisons.
    //!
    //! Dispatches on the argument types:
    //!   -# Raw pointers, std::optional, std::shared_ptr and std::unique_ptr
    //!      are dereferenced and compared using SDerefLess, so null values
    //!      order after non-null values.
    //!   -# std::pairs are compared lexicographically applying SLess to each
    //!      element.
    //!   -# Everything else is compared with SReferenceLess.
    struct MATHS_COMMON_EXPORT SLess {
        // Disabled when both arguments are raw pointers so that those are
        // dereferenced rather than compared by address.
        template<typename U, typename V, std::enable_if_t<!std::is_pointer_v<U> || !std::is_pointer_v<V>>* = nullptr>
        inline bool operator()(const U& lhs, const V& rhs) const {
            return SReferenceLess::less(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const T* lhs, const T* rhs) const {
            return SDerefLess::less(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::optional<T>& lhs, const std::optional<T>& rhs) const {
            return SDerefLess::less(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::shared_ptr<T>& lhs,
                               const std::shared_ptr<T>& rhs) const {
            return SDerefLess::less(lhs.get(), rhs.get());
        }
        template<typename T>
        inline bool operator()(const std::unique_ptr<T>& lhs,
                               const std::unique_ptr<T>& rhs) const {
            return SDerefLess::less(lhs.get(), rhs.get());
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return lexicographicalCompareWith(*this, // compare with SLess
                                              lhs.first, lhs.second, rhs.first, rhs.second);
        }
    };

    //! \brief Wrapper around various greater than comparisons.
    //!
    //! Dispatches on the argument types:
    //!   -# Raw pointers, std::optional, std::shared_ptr and std::unique_ptr
    //!      are dereferenced and compared using SDerefGreater, so null values
    //!      order before non-null values.
    //!   -# std::pairs are compared lexicographically applying SGreater to
    //!      each element.
    //!   -# Everything else is compared with SReferenceGreater.
    struct MATHS_COMMON_EXPORT SGreater {
        // Disabled when both arguments are raw pointers so that those are
        // dereferenced rather than compared by address.
        template<typename U, typename V, std::enable_if_t<!std::is_pointer_v<U> || !std::is_pointer_v<V>>* = nullptr>
        inline bool operator()(const U& lhs, const V& rhs) const {
            return SReferenceGreater::greater(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const T* lhs, const T* rhs) const {
            return SDerefGreater::greater(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::optional<T>& lhs, const std::optional<T>& rhs) const {
            return SDerefGreater::greater(lhs, rhs);
        }
        template<typename T>
        inline bool operator()(const std::shared_ptr<T>& lhs,
                               const std::shared_ptr<T>& rhs) const {
            return SDerefGreater::greater(lhs.get(), rhs.get());
        }
        template<typename T>
        inline bool operator()(const std::unique_ptr<T>& lhs,
                               const std::unique_ptr<T>& rhs) const {
            return SDerefGreater::greater(lhs.get(), rhs.get());
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return lexicographicalCompareWith(*this, // compare with SGreater
                                              lhs.first, lhs.second, rhs.first, rhs.second);
        }
    };

    //! \brief Orders std::pairs by their first element using SLess. It also
    //! supports comparing a pair's first element with a bare value.
    //!
    //! \note That while this functionality can be implemented by boost
    //! bind, since it overloads the comparison operators, the resulting
    //! code is more than an order of magnitude slower than this version.
    struct MATHS_COMMON_EXPORT SFirstLess {
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return s_Less(lhs.first, rhs.first);
        }
        template<typename U, typename V>
        inline bool operator()(const U& lhs, const std::pair<U, V>& rhs) const {
            return s_Less(lhs, rhs.first);
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const U& rhs) const {
            return s_Less(lhs.first, rhs);
        }
        SLess s_Less;
    };

    //! \brief Orders std::pairs by their first element using SGreater. It
    //! also supports comparing a pair's first element with a bare value.
    //!
    //! \note That while this functionality can be implemented by boost
    //! bind, since it overloads the comparison operators, the resulting
    //! code is more than an order of magnitude slower than this version.
    struct MATHS_COMMON_EXPORT SFirstGreater {
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return s_Greater(lhs.first, rhs.first);
        }
        template<typename U, typename V>
        inline bool operator()(const U& lhs, const std::pair<U, V>& rhs) const {
            return s_Greater(lhs, rhs.first);
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const U& rhs) const {
            return s_Greater(lhs.first, rhs);
        }
        SGreater s_Greater;
    };

    //! \brief Orders std::pairs by their second element using SLess. It also
    //! supports comparing a pair's second element with a bare value.
    //!
    //! \note That while this functionality can be implemented by boost
    //! bind, since it overloads the comparison operators, the resulting
    //! code is more than an order of magnitude slower than this version.
    struct MATHS_COMMON_EXPORT SSecondLess {
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return s_Less(lhs.second, rhs.second);
        }
        template<typename U, typename V>
        inline bool operator()(const V& lhs, const std::pair<U, V>& rhs) const {
            return s_Less(lhs, rhs.second);
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const V& rhs) const {
            return s_Less(lhs.second, rhs);
        }
        SLess s_Less;
    };

    //! \brief Orders std::pairs by their second element using SGreater. It
    //! also supports comparing a pair's second element with a bare value.
    //!
    //! \note That while this functionality can be implemented by boost
    //! bind, since it overloads the comparison operators, the resulting
    //! code is more than an order of magnitude slower than this version.
    struct MATHS_COMMON_EXPORT SSecondGreater {
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const std::pair<U, V>& rhs) const {
            return s_Greater(lhs.second, rhs.second);
        }
        template<typename U, typename V>
        inline bool operator()(const V& lhs, const std::pair<U, V>& rhs) const {
            return s_Greater(lhs, rhs.second);
        }
        template<typename U, typename V>
        inline bool operator()(const std::pair<U, V>& lhs, const V& rhs) const {
            return s_Greater(lhs.second, rhs);
        }
        SGreater s_Greater;
    };

private:
    //! Lexicographically compares from the I'th pair of values onwards.
    //!
    //! \p args holds the left hand values followed by the same number of
    //! right hand values. The I'th pair is args[I] and args[I + n], where
    //! n = sizeof...(ARGS) / 2. The next pair is only compared if the
    //! current pair are equivalent under \p pred.
    template<std::size_t I, typename PRED, typename... ARGS>
    static bool lexicographicalCompareAt(const PRED& pred, const ARGS&... args) {
        static_assert(sizeof...(ARGS) % 2 == 0, "The number of values to compare must be equal");
        static_assert(sizeof...(ARGS) > 0, "There must be at least one value to compare");
        constexpr std::size_t n{sizeof...(ARGS) / 2};
        // A tuple of references to the arguments, so no values are copied.
        const auto values = std::forward_as_tuple(args...);
        const auto& lhs = std::get<I>(values);
        const auto& rhs = std::get<I + n>(values);
        if (pred(lhs, rhs)) {
            return true;
        }
        if constexpr (I + 1 < n) {
            // lhs is not less than rhs. Continue only if they are equivalent.
            return pred(rhs, lhs) == false &&
                   lexicographicalCompareAt<I + 1>(pred, args...);
        } else {
            return false;
        }
    }

public:
    //! Equivalent to std::lexicographical_compare for mixed types.
    //!
    //! The first half of \p args are the left hand values and the second
    //! half the right hand values, i.e. for sequences (a1, a2) and (b1, b2)
    //! call lexicographicalCompareWith(pred, a1, a2, b1, b2).
    //!
    //! Uses \p pred for comparison which must be able to compare each type.
    //! Comparing two empty sequences returns false.
    template<typename PRED, typename... ARGS>
    static bool lexicographicalCompareWith([[maybe_unused]] const PRED& pred,
                                           const ARGS&... args) {
        if constexpr (sizeof...(ARGS) == 0) {
            // Two empty sequences are equivalent so lhs is not less than rhs.
            return false;
        } else {
            return lexicographicalCompareAt<0>(pred, args...);
        }
    }

    //! Calls lexicographicalCompareWith supplying std::less<> for comparison.
    template<typename... ARGS>
    static bool lexicographicalCompare(const ARGS&... args) {
        return lexicographicalCompareWith(std::less<>{}, args...);
    }

    //! Simultaneously sort multiple vectors using \p pred order of \p keys.
    //!
    //! This simultaneously sorts a number of vectors based on ordering a
    //! collection of keys. For example, the following code:
    //! \code{cpp}
    //! std::vector<double> ids{3.1, 2.2, 0.5, 1.5};
    //! std::vector<std::string> names{"a", "b", "c", "d"};
    //! maths::common::COrderings::simultaneousSort(ids, names);
    //! for (std::size_t i = 0; i < 4; ++i) {
    //!     std::cout << ids[i] << ' ' << names[i] << std::endl;
    //! }
    //! \endcode
    //!
    //! Will produce the following output:
    //! <pre>
    //! 0.5 c
    //! 1.5 d
    //! 2.2 b
    //! 3.1 a
    //! </pre>
    //!
    //! \note The complexity is O(N log(N)) where N is the length of the
    //! containers.
    //! \warning All containers must have the same length.
    //! NOTE(review): the meaning of the returned bool is defined by the
    //! out-of-line implementation (presumably false if the lengths differ),
    //! confirm there.
    template<typename PRED, typename K, typename... V>
    static bool simultaneousSortWith(const PRED& pred, K&& keys, V&&... values);

    //! Calls simultaneousSortWith supplying std::less<> for comparison.
    template<typename K, typename... V>
    static bool simultaneousSort(K&& keys, V&&... values) {
        return simultaneousSortWith(std::less<>(), std::forward<K>(keys),
                                    std::forward<V>(values)...);
    }
};
}
}
}
#endif // INCLUDED_ml_maths_common_COrderings_h