lib/core/CStateMachine.cc (189 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <core/CStateMachine.h>
#include <core/CFastMutex.h>
#include <core/CHashing.h>
#include <core/CLogger.h>
#include <core/CScopedFastLock.h>
#include <core/CStatePersistInserter.h>
#include <core/CStateRestoreTraverser.h>
#include <core/CoreTypes.h>
#include <core/RestoreMacros.h>
#include <core/UnwrapRef.h>
#include <sstream>
namespace ml {
namespace core {
namespace {
// CStateMachine
//const std::string MACHINE_TAG("a"); No longer used
const core::TPersistenceTag STATE_TAG("b", "state");
// CStateMachine::SMachine
const std::string ALPHABET_TAG("a");
const std::string STATES_TAG("b");
const std::string TRANSITION_FUNCTION_TAG("c");
std::size_t BAD_MACHINE = std::numeric_limits<std::size_t>::max();
CFastMutex mutex;
}
void CStateMachine::expectedNumberMachines(std::size_t number) {
CScopedFastLock lock(mutex);
ms_Machines.capacity(number);
}
CStateMachine CStateMachine::create(const TStrVec& alphabet,
const TStrVec& states,
const TSizeVecVec& transitionFunction,
std::size_t state) {
// Validate that the alphabet, states, transition function,
// and initial state are consistent.
CStateMachine result;
if (state >= states.size()) {
LOG_ERROR(<< "Invalid initial state: " << state);
return result;
}
if (alphabet.empty() || alphabet.size() != transitionFunction.size()) {
LOG_ERROR(<< "Bad alphabet: " << alphabet);
return result;
}
for (const auto& function : transitionFunction) {
if (states.size() != function.size()) {
LOG_ERROR(<< "Bad transition function row: " << function);
return result;
}
}
// We use the standard double lock pattern with an atomic size to
// indicate that a machine is ready to use. Because we are storing
// the machines in a custom deque container a concurrent push_back
// doesn't invalidate access to any other existing machine.
SLookupMachine machine(alphabet, states, transitionFunction);
std::size_t size = ms_Machines.size();
std::size_t m = find(0, size, machine);
if (m == size || machine != ms_Machines[m]) {
CScopedFastLock lock(mutex);
m = find(0, ms_Machines.size(), machine);
if (m == ms_Machines.size()) {
ms_Machines.push_back(SMachine(alphabet, states, transitionFunction));
}
}
result.m_Machine = m;
result.m_State = state;
return result;
}
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser,
const TSizeSizeMap& mapping) {
do {
const std::string& name = traverser.name();
RESTORE_BUILT_IN(STATE_TAG, m_State)
} while (traverser.next());
if (mapping.size() > 0) {
auto mapped = mapping.find(m_State);
if (mapped != mapping.end()) {
m_State = mapped->second;
} else {
LOG_ERROR(<< "Bad mapping '" << mapping << "' state = " << m_State);
return false;
}
}
return true;
}
void CStateMachine::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
inserter.insertValue(STATE_TAG, m_State);
}
bool CStateMachine::bad() const {
return m_Machine == BAD_MACHINE;
}
bool CStateMachine::apply(std::size_t symbol) {
const TSizeVecVec& table = ms_Machines[m_Machine].s_TransitionFunction;
if (symbol >= table.size()) {
LOG_ERROR(<< "Bad symbol " << symbol << " not in alphabet [" << table.size() << "]");
return false;
}
if (m_State >= table[symbol].size()) {
LOG_ERROR(<< "Bad state " << m_State << " not in states ["
<< table[symbol].size() << "]");
return false;
}
m_State = table[symbol][m_State];
return true;
}
std::size_t CStateMachine::state() const {
return m_State;
}
std::string CStateMachine::printState(std::size_t state) const {
if (state >= ms_Machines[m_Machine].s_States.size()) {
return "State Not Found";
}
return ms_Machines[m_Machine].s_States[state];
}
std::string CStateMachine::printSymbol(std::size_t symbol) const {
if (symbol >= ms_Machines[m_Machine].s_Alphabet.size()) {
return "Symbol Not Found";
}
return ms_Machines[m_Machine].s_Alphabet[symbol];
}
std::uint64_t CStateMachine::checksum() const {
return CHashing::hashCombine(static_cast<std::uint64_t>(m_Machine),
static_cast<std::uint64_t>(m_State));
}
std::size_t CStateMachine::numberMachines() {
CScopedFastLock lock(mutex);
return ms_Machines.size();
}
void CStateMachine::clear() {
CScopedFastLock lock(mutex);
ms_Machines.clear();
}
std::size_t CStateMachine::find(std::size_t begin, std::size_t end, const SLookupMachine& machine) {
for (std::size_t i = begin; i < end; ++i) {
if (machine == ms_Machines[i]) {
return i;
}
}
return end;
}
CStateMachine::CStateMachine() : m_Machine(BAD_MACHINE), m_State(0) {
}
CStateMachine::SMachine::SMachine(const TStrVec& alphabet,
const TStrVec& states,
const TSizeVecVec& transitionFunction)
: s_Alphabet(alphabet), s_States(states), s_TransitionFunction(transitionFunction) {
}
CStateMachine::SMachine::SMachine(const SMachine& other)
: s_Alphabet(other.s_Alphabet), s_States(other.s_States),
s_TransitionFunction(other.s_TransitionFunction) {
}
CStateMachine::SLookupMachine::SLookupMachine(const TStrVec& alphabet,
const TStrVec& states,
const TSizeVecVec& transitionFunction)
: s_Alphabet(alphabet), s_States(states), s_TransitionFunction(transitionFunction) {
}
bool CStateMachine::SLookupMachine::operator==(const SMachine& rhs) const {
return unwrap_ref(s_TransitionFunction) == rhs.s_TransitionFunction &&
unwrap_ref(s_Alphabet) == rhs.s_Alphabet &&
unwrap_ref(s_States) == rhs.s_States;
}
CStateMachine::CMachineDeque::CMachineDeque()
: m_Capacity(DEFAULT_CAPACITY), m_NumberMachines(0) {
m_Machines.push_back(TMachineVec());
m_Machines.back().reserve(m_Capacity);
}
void CStateMachine::CMachineDeque::capacity(std::size_t capacity) {
m_Capacity = capacity;
}
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos_) const {
std::size_t pos{pos_};
for (const auto& machines : m_Machines) {
if (pos < machines.size()) {
return machines[pos];
}
pos -= machines.size();
}
LOG_ABORT(<< "Invalid index '" << pos_ << "'");
}
std::size_t CStateMachine::CMachineDeque::size() const {
return m_NumberMachines.load(std::memory_order_acquire);
}
void CStateMachine::CMachineDeque::push_back(const SMachine& machine) {
if (m_Machines.back().size() == m_Capacity) {
m_Machines.push_back(TMachineVec());
m_Machines.back().reserve(m_Capacity);
}
m_Machines.back().push_back(machine);
m_NumberMachines.store(this->size() + 1, std::memory_order_release);
}
void CStateMachine::CMachineDeque::clear() {
m_NumberMachines.store(0);
m_Machines.clear();
m_Machines.push_back(TMachineVec());
m_Machines.back().reserve(m_Capacity);
}
CStateMachine::CMachineDeque CStateMachine::ms_Machines;
}
}