lib/maths/common/CPRNG.cc (188 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <maths/common/CPRNG.h>

#include <core/CPersistUtils.h>
#include <core/CStringUtils.h>

#include <maths/common/CChecksum.h>

#include <algorithm>

namespace ml {
namespace maths {
namespace common {
namespace {
namespace detail {

//! Discard a sequence of \p n random numbers.
//!
//! Advances \p rng by invoking it \p n times and ignoring the results.
template<typename PRNG>
inline void discard(std::uint64_t n, PRNG& rng) {
    for (/**/; n > 0; --n) {
        rng();
    }
}

//! Rotate \p x left by \p k bits.
//!
//! The shift counts are reduced modulo 64 so that no shift by the full
//! word width (which is undefined behaviour) can occur for any \p k.
std::uint64_t rotl(const std::uint64_t x, int k) {
    k &= 63;
    return (x << k) | (x >> ((64 - k) & 63));
}
}
}

////// CSplitMix64 //////

//! Create with the default seed.
CPRNG::CSplitMix64::CSplitMix64() : m_X(0) {
    this->seed();
}

//! Create with the state set to \p seed.
CPRNG::CSplitMix64::CSplitMix64(result_type seed) : m_X(0) {
    this->seed(seed);
}

//! Two generators are equal if their states are equal.
bool CPRNG::CSplitMix64::operator==(CSplitMix64 other) const {
    return m_X == other.m_X;
}

//! Reset the state to the default value of zero.
void CPRNG::CSplitMix64::seed() {
    m_X = 0;
}

//! Set the state to \p seed.
void CPRNG::CSplitMix64::seed(result_type seed) {
    m_X = seed;
}

//! Advance the state by A and return a mixed version of the new state.
CPRNG::CSplitMix64::result_type CPRNG::CSplitMix64::operator()() {
    result_type x = (m_X += A);
    x = (x ^ (x >> 30)) * B;
    x = (x ^ (x >> 27)) * C;
    return x ^ (x >> 31);
}

//! Discard the next \p n random numbers.
void CPRNG::CSplitMix64::discard(result_type n) {
    detail::discard(n, *this);
}

//! Persist the state to a string.
std::string CPRNG::CSplitMix64::toString() const {
    return core::CStringUtils::typeToString(m_X);
}

//! Restore the state from \p state.
//!
//! \return True if \p state could be parsed and false otherwise.
bool CPRNG::CSplitMix64::fromString(const std::string& state) {
    return core::CStringUtils::stringToType(state, m_X);
}

//! Get a checksum of the state seeded with \p seed.
std::uint64_t CPRNG::CSplitMix64::checksum(std::uint64_t seed) const {
    return CChecksum::calculate(seed, m_X);
}

const CPRNG::CSplitMix64::result_type CPRNG::CSplitMix64::A(0x9E3779B97F4A7C15);
const CPRNG::CSplitMix64::result_type CPRNG::CSplitMix64::B(0xBF58476D1CE4E5B9);
const CPRNG::CSplitMix64::result_type CPRNG::CSplitMix64::C(0x94D049BB133111EB);

////// CXorOShiro128Plus //////

//! Create with the default seed.
CPRNG::CXorOShiro128Plus::CXorOShiro128Plus() {
    this->seed();
}

//! Create with the state derived from \p seed.
CPRNG::CXorOShiro128Plus::CXorOShiro128Plus(result_type seed) {
    this->seed(seed);
}

//! Two generators are equal if both state words are equal.
bool CPRNG::CXorOShiro128Plus::operator==(const CXorOShiro128Plus& other) const {
    return std::equal(&m_X[0], &m_X[2], &other.m_X[0]);
}

//! Reset to the state derived from the default seed of zero.
void CPRNG::CXorOShiro128Plus::seed() {
    this->seed(0);
}

//! Fill both state words from a CSplitMix64 generator seeded with \p seed.
void CPRNG::CXorOShiro128Plus::seed(result_type seed) {
    CSplitMix64 seeds(seed);
    seeds.generate(&m_X[0], &m_X[2]);
}

//! Return the sum of the two state words then advance the state.
CPRNG::CXorOShiro128Plus::result_type CPRNG::CXorOShiro128Plus::operator()() {
    result_type x0 = m_X[0];
    result_type x1 = m_X[1];
    result_type result = x0 + x1;
    x1 ^= x0;
    m_X[0] = detail::rotl(x0, 55) ^ x1 ^ (x1 << 14);
    m_X[1] = detail::rotl(x1, 36);
    return result;
}

//! Discard the next \p n random numbers.
void CPRNG::CXorOShiro128Plus::discard(result_type n) {
    detail::discard(n, *this);
}

//! Jump the generator a long way ahead in its sequence.
//!
//! For every set bit of the JUMP polynomial the current state is xor'ed
//! into an accumulator; the generator is stepped once per bit and the
//! accumulator becomes the new state.
//! NOTE(review): in the reference algorithm these constants correspond to
//! a jump of 2^64 calls - confirm against the class documentation.
void CPRNG::CXorOShiro128Plus::jump() {
    result_type x[2] = {0};

    for (std::size_t i = 0; i < 2; ++i) {
        for (unsigned int b = 0; b < 64; ++b) {
            if (JUMP[i] & (1ULL << b)) {
                x[0] ^= m_X[0];
                x[1] ^= m_X[1];
            }
            this->operator()();
        }
    }

    m_X[0] = x[0];
    m_X[1] = x[1];
}

//! Persist the two state words to a delimited string.
std::string CPRNG::CXorOShiro128Plus::toString() const {
    const result_type* begin = &m_X[0];
    const result_type* end = &m_X[2];
    return core::CPersistUtils::toString(begin, end);
}

//! Restore the two state words from \p state.
//!
//! \return True if \p state could be parsed and false otherwise.
bool CPRNG::CXorOShiro128Plus::fromString(const std::string& state) {
    return core::CPersistUtils::fromString(state, &m_X[0], &m_X[2]);
}

//! Get a checksum of the state seeded with \p seed.
std::uint64_t CPRNG::CXorOShiro128Plus::checksum(std::uint64_t seed) const {
    return CChecksum::calculate(seed, m_X);
}

const CPRNG::CXorOShiro128Plus::result_type CPRNG::CXorOShiro128Plus::JUMP[] = {
    0xbeac0467eba5facb, 0xd86b048b86aa9922};

////// CXorShift1024Mult //////

//! Create with the default seed.
CPRNG::CXorShift1024Mult::CXorShift1024Mult() : m_P(0) {
    this->seed();
}

//! Create with the state derived from \p seed.
CPRNG::CXorShift1024Mult::CXorShift1024Mult(result_type seed) : m_P(0) {
    this->seed(seed);
}

//! Two generators are equal if their positions and all sixteen state
//! words are equal.
bool CPRNG::CXorShift1024Mult::operator==(const CXorShift1024Mult& other) const {
    return m_P == other.m_P && std::equal(&m_X[0], &m_X[16], &other.m_X[0]);
}

//! Reset to the state derived from the default seed of zero.
void CPRNG::CXorShift1024Mult::seed() {
    this->seed(0);
}

//! Fill the sixteen state words from a CSplitMix64 generator seeded with
//! \p seed and rewind the position.
void CPRNG::CXorShift1024Mult::seed(result_type seed) {
    CSplitMix64 seeds(seed);
    seeds.generate(&m_X[0], &m_X[16]);
    // The position is part of the state: it must be rewound as well or
    // reseeding a used generator would not reproduce the sequence of a
    // freshly constructed one (and the two would compare unequal).
    m_P = 0;
}

//! Advance the state and return the newly written state word scaled by A.
CPRNG::CXorShift1024Mult::result_type CPRNG::CXorShift1024Mult::operator()() {
    result_type x0 = m_X[m_P];
    m_P = (m_P + 1) & 15;
    result_type x1 = m_X[m_P];
    x1 ^= x1 << 31;
    m_X[m_P] = x1 ^ x0 ^ (x1 >> 11) ^ (x0 >> 30);
    return m_X[m_P] * A;
}

//! Discard the next \p n random numbers.
void CPRNG::CXorShift1024Mult::discard(result_type n) {
    detail::discard(n, *this);
}

//! Jump the generator a long way ahead in its sequence.
//!
//! For every set bit of the JUMP polynomial the current state, read
//! starting from the current position, is xor'ed into an accumulator;
//! the generator is stepped once per bit and the accumulator is written
//! back relative to the final position.
//! NOTE(review): in the reference algorithm these constants correspond to
//! a jump of 2^512 calls - confirm against the class documentation.
void CPRNG::CXorShift1024Mult::jump() {
    result_type t[16] = {0};

    for (std::size_t i = 0; i < 16; ++i) {
        for (unsigned int b = 0; b < 64; ++b) {
            if (JUMP[i] & (1ULL << b)) {
                for (int j = 0; j < 16; ++j) {
                    t[j] ^= m_X[(j + m_P) & 15];
                }
            }
            this->operator()();
        }
    }

    for (int j = 0; j < 16; ++j) {
        m_X[(j + m_P) & 15] = t[j];
    }
}

//! Persist the state as the sixteen delimited state words followed by
//! the pair delimiter and the position.
std::string CPRNG::CXorShift1024Mult::toString() const {
    const result_type* begin = &m_X[0];
    const result_type* end = &m_X[16];
    return core::CPersistUtils::toString(begin, end) +
           core::CPersistUtils::PAIR_DELIMITER + core::CStringUtils::typeToString(m_P);
}

//! Restore the state from \p state, which is expected to be in the
//! format written by toString.
//!
//! \return True if \p state could be parsed and the position it contains
//! is a valid index into the state words, and false otherwise. The
//! position is only updated if the whole of \p state was restored; the
//! state words may be partially overwritten if parsing them fails.
bool CPRNG::CXorShift1024Mult::fromString(std::string state) {
    std::size_t delimPos = state.find(core::CPersistUtils::PAIR_DELIMITER);
    if (delimPos == std::string::npos) {
        return false;
    }

    // Parse the position into a temporary so a bad value never reaches m_P.
    decltype(m_P) position{};
    if (!core::CStringUtils::stringToType(state.substr(delimPos + 1), position)) {
        return false;
    }
    // operator() reads m_X[m_P] before it masks the position so a value
    // outside [0, 15] would index outside the state array.
    if ((position & 15) != position) {
        return false;
    }

    state.resize(delimPos);
    if (!core::CPersistUtils::fromString(state, &m_X[0], &m_X[16])) {
        return false;
    }

    m_P = position;
    return true;
}

//! Get a checksum of the full state, i.e. the position and the state
//! words, seeded with \p seed.
std::uint64_t CPRNG::CXorShift1024Mult::checksum(std::uint64_t seed) const {
    // Generators which differ only in position produce different sequences
    // and compare unequal so the position must contribute to the checksum.
    seed = CChecksum::calculate(seed, m_P);
    return CChecksum::calculate(seed, m_X);
}

const CPRNG::CXorShift1024Mult::result_type CPRNG::CXorShift1024Mult::A(1181783497276652981);
const CPRNG::CXorShift1024Mult::result_type CPRNG::CXorShift1024Mult::JUMP[16] = {
    0x84242f96eca9c41d, 0xa3c65b8776f96855, 0x5b34a39f070b5837, 0x4489affce4f31a1e,
    0x2ffeeb0a48316f40, 0xdc2d9891fe68c022, 0x3659132bb12fea70, 0xaac17d8efa43cab8,
    0xc4cb815590989b13, 0x5ee975283d71c93b, 0x691548c86c1bd540, 0x7910c41d10a1e6a5,
    0x0b5fc64563b3e2a8, 0x047f7684e9fc949d, 0xb99181f2d8f685ca, 0x284600e3f30e38c3};
}
}
}