lib/api/CPerPartitionCategoryIdMapper.cc (59 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <api/CPerPartitionCategoryIdMapper.h>
#include <core/CLogger.h>
#include <core/CStatePersistInserter.h>
#include <core/CStateRestoreTraverser.h>
namespace {
// Tag under which each global category ID is written to persisted state by
// acceptPersistInserter() and recognised again by acceptRestoreTraverser().
// Changing the value would stop previously persisted entries from matching.
const std::string GLOBAL_ID_TAG("a");
}
namespace ml {
namespace api {
// Construct a mapper for a single categorizer.
//
// \param categorizerKey Key identifying the categorizer; taken by value and
//        moved into the member so the caller's argument is not copied twice.
// \param nextGlobalIdSupplier Callable that is invoked whenever map() has to
//        create a new mapping; also taken by value and moved into place.
CPerPartitionCategoryIdMapper::CPerPartitionCategoryIdMapper(
    std::string categorizerKey,
    TNextGlobalIdSupplier nextGlobalIdSupplier)
    : m_CategorizerKey(std::move(categorizerKey)),
      m_NextGlobalIdSupplier(std::move(nextGlobalIdSupplier)) {
}
// Translate a local category ID into the corresponding global category ID,
// creating the mapping on first sight of a local ID.
//
// \param localCategoryId The local ID to translate.
// \return The stored global ID for a valid local ID; for an invalid local ID
//         a global ID built directly from the local ID's raw value.
CGlobalCategoryId CPerPartitionCategoryIdMapper::map(model::CLocalCategoryId localCategoryId) {
    // Invalid local IDs are never stored: wrap the raw ID and hand it back.
    if (localCategoryId.isValid() == false) {
        return CGlobalCategoryId{localCategoryId.id()};
    }

    const std::size_t slot{localCategoryId.index()};
    const std::size_t known{m_Mappings.size()};

    // Fast path: a mapping for this local ID already exists.
    if (slot < known) {
        return m_Mappings[slot];
    }

    // A gap means earlier local IDs were never mapped.  Report it, then grow
    // the container to `slot` entries so that the new mapping appended below
    // lands at the requested index.
    if (slot > known) {
        LOG_ERROR(<< "Bad category mappings: " << (slot - known)
                  << " local to global category ID mappings missing for partition "
                  << m_CategorizerKey);
        m_Mappings.resize(slot);
    }

    // The container now holds exactly `slot` entries, so the next global ID
    // obtained from the supplier is appended at index `slot`.
    m_Mappings.emplace_back(m_NextGlobalIdSupplier(), m_CategorizerKey, localCategoryId);
    return m_Mappings[slot];
}
const std::string& CPerPartitionCategoryIdMapper::categorizerKey() const {
return m_CategorizerKey;
}
// Create an independent copy of this mapper.
//
// \return A pointer to a new CPerPartitionCategoryIdMapper that is
//         copy-constructed from this object.
CCategoryIdMapper::TCategoryIdMapperPtr CPerPartitionCategoryIdMapper::clone() const {
    auto copy = std::make_shared<CPerPartitionCategoryIdMapper>(*this);
    return copy;
}
// Persist the mappings held by this object.
//
// One value is inserted per stored mapping, in container order, each under
// GLOBAL_ID_TAG.  Only the global ID is written.
//
// \param inserter Receives the persisted values.
void CPerPartitionCategoryIdMapper::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
    for (auto mapping = m_Mappings.cbegin(); mapping != m_Mappings.cend(); ++mapping) {
        inserter.insertValue(GLOBAL_ID_TAG, mapping->globalId());
    }
}
// Rebuild the mappings from persisted state.
//
// Any existing mappings are discarded first.  Every entry at the traverser's
// current level whose name equals GLOBAL_ID_TAG is parsed as an integer global
// ID and appended; entries with any other name are skipped.  The local ID of
// each restored mapping is its position in the container.
//
// \param traverser Positioned on the first entry to examine.
// \return false as soon as a tagged value cannot be parsed (mappings restored
//         before that point are left in place); true otherwise.
bool CPerPartitionCategoryIdMapper::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
    m_Mappings.clear();

    do {
        // `continue` in a do-while jumps to the loop condition, so the
        // traverser is still advanced past entries that are not ours.
        if (traverser.name() != GLOBAL_ID_TAG) {
            continue;
        }

        int restoredId{model::CLocalCategoryId::SOFT_CATEGORIZATION_FAILURE_ERROR};
        if (core::CStringUtils::stringToType(traverser.value(), restoredId) == false) {
            LOG_ERROR(<< "Invalid global ID in " << traverser.value());
            return false;
        }

        // The new entry's index is the container size before the append.
        const std::size_t nextIndex{m_Mappings.size()};
        m_Mappings.emplace_back(restoredId, m_CategorizerKey,
                                model::CLocalCategoryId{nextIndex});
    } while (traverser.next());

    return true;
}
}
}