lib/api/CDataFrameAnalyzer.cc (308 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#include <api/CDataFrameAnalyzer.h>

#include <core/CContainerPrinter.h>
#include <core/CDataFrame.h>
#include <core/CFloatStorage.h>
#include <core/CJsonOutputStreamWrapper.h>
#include <core/CLogger.h>
#include <core/CStopWatch.h>
#include <core/CVectorRange.h>

#include <maths/common/CBasicStatistics.h>
#include <maths/common/COrderings.h>
#include <maths/common/COrderingsSimultaneousSort.h>

#include <api/CDataFrameAnalysisInstrumentation.h>
#include <api/CDataFrameAnalysisSpecification.h>
#include <api/CDataSummarizationJsonWriter.h>
#include <api/CInferenceModelDefinition.h>
#include <api/CInferenceModelMetadata.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
namespace ml {
namespace api {
namespace {
using TStrVec = std::vector<std::string>;

// Control message types:
// A control message whose value is this single character signals that all
// data rows have been received and the analysis should run. constexpr so it
// is guaranteed usable as a switch/case label.
constexpr char FINISHED_DATA_CONTROL_MESSAGE_FIELD_VALUE{'$'};

// Result types: top-level key of each per-row result object.
const std::string ROW_RESULTS{"row_results"};

// Row result fields:
// Key carrying the document hash supplied with the row.
const std::string CHECKSUM{"checksum"};
// Key of the object holding the analysis-specific result values.
const std::string RESULTS{"results"};
}
// Takes ownership of the analysis specification, the data frame together
// with its paired temporary directory, and the supplier used to create the
// results output stream. A missing data frame is reported as a fatal
// internal error.
CDataFrameAnalyzer::CDataFrameAnalyzer(TDataFrameAnalysisSpecificationUPtr analysisSpecification,
                                       TDataFrameUPtrTemporaryDirectoryPtrPr frameAndDirectory,
                                       TJsonOutputStreamWrapperUPtrSupplier resultsStreamSupplier)
    : m_AnalysisSpecification{std::move(analysisSpecification)},
      m_ResultsStreamSupplier{std::move(resultsStreamSupplier)} {
    // Move the frame first and then its directory, element by element.
    m_DataFrame = std::move(std::get<0>(frameAndDirectory));
    m_DataFrameDirectory = std::move(std::get<1>(frameAndDirectory));
    if (m_DataFrame == nullptr) {
        HANDLE_FATAL(<< "Internal error: missing data frame. Please report this problem.");
    }
}

// Defined out of line in this translation unit.
CDataFrameAnalyzer::~CDataFrameAnalyzer() = default;
// True once the input has been found to carry a control message column,
// i.e. the control field index is a real (non-negative) column position.
bool CDataFrameAnalyzer::usingControlMessages() const {
    return (m_ControlFieldIndex < 0) == false;
}
// Processes one input record, which is either a control message or a data row.
//
// Special columns are identified by the field name CONTROL_MESSAGE_FIELD_NAME
// ("."). The control message column is either the last field or absent; when
// it is absent the analysis only starts when run() is called explicitly.
//
// Returning false immediately stops processing of the input stream, so any
// error reported from this context is emitted at most once.
bool CDataFrameAnalyzer::handleRecord(const TStrVec& fieldNames, const TStrVec& fieldValues) {
    if (m_AnalysisSpecification == nullptr) {
        // The problem was already logged when the specification was created.
        return false;
    }

    // Work out where the special fields live (only done until it succeeds),
    // then check the record has the expected width. Both helpers log their
    // own errors.
    bool layoutKnown{this->readyToReceiveControlMessages() ||
                     this->prepareToReceiveControlMessages(fieldNames)};
    if (layoutKnown == false || this->sufficientFieldValues(fieldValues) == false) {
        return false;
    }

    if (this->isControlMessage(fieldValues) == false) {
        // An ordinary data row.
        this->captureFieldNames(fieldNames);
        this->addRowToDataFrame(fieldValues);
        return true;
    }
    return this->handleControlMessage(fieldValues);
}
// Tells the data frame that no more rows will be written and logs the final
// row count. Does nothing if there is no data frame.
void CDataFrameAnalyzer::receivedAllRows() {
    if (m_DataFrame == nullptr) {
        return;
    }
    m_DataFrame->finishWritingRows();
    LOG_DEBUG(<< "Received " << m_DataFrame->numberRows() << " rows");
}
void CDataFrameAnalyzer::run() {
if (m_AnalysisSpecification == nullptr || m_DataFrame == nullptr) {
return;
}
if (m_AnalysisSpecification->validate(*m_DataFrame) == false) {
return;
}
LOG_DEBUG(<< "Running analysis...");
// We create the writer in run so that when it is finished destructors
// get called and the wrapped stream does its job to close the array.
auto* analysisRunner = m_AnalysisSpecification->runner();
if (analysisRunner != nullptr) {
// We currently use a stream factory because the results are wrapped in
// an array. This is managed by the CJsonOutputStreamWrapper constructor
// and destructor. We should probably migrate to NDJSON format at which
// point this would no longer be necessary.
auto outStream = m_ResultsStreamSupplier();
auto& instrumentation = analysisRunner->instrumentation();
CDataFrameAnalysisInstrumentation::CScopeSetOutputStream setStream{
instrumentation, *outStream};
instrumentation.updateMemoryUsage(
static_cast<std::int64_t>(m_DataFrame->memoryUsage()));
instrumentation.flush();
analysisRunner->run(*m_DataFrame);
core::CBoostJsonConcurrentLineWriter outputWriter{*outStream};
CDataFrameAnalysisInstrumentation::monitor(instrumentation, outputWriter);
analysisRunner->waitToFinish();
this->writeInferenceModel(*analysisRunner, outputWriter);
this->writeResultsOf(*analysisRunner, outputWriter);
this->writeInferenceModelMetadata(*analysisRunner, outputWriter);
this->writeDataSummarization(*analysisRunner, outputWriter);
}
}
// Read-only access to the data frame. A missing frame is reported as a fatal
// internal error.
const core::CDataFrame& CDataFrameAnalyzer::dataFrame() const {
    if (m_DataFrame != nullptr) {
        return *m_DataFrame;
    }
    HANDLE_FATAL(<< "Internal error: missing data frame. Please report this problem.");
    return *m_DataFrame;
}
// True once prepareToReceiveControlMessages has decided the field layout,
// i.e. the control field index no longer holds the FIELD_UNSET sentinel.
bool CDataFrameAnalyzer::readyToReceiveControlMessages() const {
    return (m_ControlFieldIndex == FIELD_UNSET) == false;
}
// Determines from the field names where the data, document hash and control
// message fields are.
//
// When called via the Java API, the last two columns are special:
//   - the penultimate holds a 32 bit hash of the document,
//   - the last holds the control message.
// Both are named CONTROL_MESSAGE_FIELD_NAME so they cannot collide with a real
// field name. If no such field exists, every column is data.
//
// Returns false (after a fatal error) if special fields are present but are
// not exactly the final two columns.
bool CDataFrameAnalyzer::prepareToReceiveControlMessages(const TStrVec& fieldNames) {
    auto first = fieldNames.begin();
    auto last = fieldNames.end();

    auto docHashPos = std::find(first, last, CONTROL_MESSAGE_FIELD_NAME);
    if (docHashPos == last) {
        // No special fields at all.
        m_ControlFieldIndex = FIELD_MISSING;
        m_BeginDataFieldValues = 0;
        m_EndDataFieldValues = static_cast<std::ptrdiff_t>(fieldNames.size());
        m_DocHashFieldIndex = FIELD_MISSING;
        return true;
    }

    auto controlMessagePos = std::find(docHashPos + 1, last, CONTROL_MESSAGE_FIELD_NAME);
    bool inFinalTwoPositions{fieldNames.size() >= 2 && docHashPos == last - 2 &&
                             controlMessagePos == last - 1};
    if (inFinalTwoPositions == false) {
        HANDLE_FATAL(<< "Input error: expected exactly two special "
                     << "fields in last two positions but got '" << fieldNames
                     << "'. Please report this problem.");
        return false;
    }

    m_ControlFieldIndex = controlMessagePos - first;
    m_BeginDataFieldValues = 0;
    m_EndDataFieldValues = docHashPos - first;
    m_DocHashFieldIndex = m_ControlFieldIndex - 1;
    return true;
}
// A record is a control message when a control column exists and its value
// in this record is non-empty.
bool CDataFrameAnalyzer::isControlMessage(const TStrVec& fieldValues) const {
    return m_ControlFieldIndex >= 0 && fieldValues[m_ControlFieldIndex].empty() == false;
}
// Checks the record has one value per specified column plus, when control
// messages are in use, the two special trailing fields. A mismatch is a fatal
// input error.
bool CDataFrameAnalyzer::sufficientFieldValues(const TStrVec& fieldValues) const {
    std::size_t numberSpecialFields{0};
    if (m_ControlFieldIndex >= 0) {
        numberSpecialFields = 2;
    }
    std::size_t expectedNumberFieldValues{m_AnalysisSpecification->numberColumns() +
                                          numberSpecialFields};
    if (fieldValues.size() == expectedNumberFieldValues) {
        return true;
    }
    HANDLE_FATAL(<< "Input error: expected " << expectedNumberFieldValues << " field"
                 << " values and got " << fieldValues << ". Please report this problem.");
    return false;
}
bool CDataFrameAnalyzer::handleControlMessage(const TStrVec& fieldValues) {
LOG_TRACE(<< "Control message: '" << fieldValues[m_ControlFieldIndex] << "'");
bool unrecognised{false};
switch (fieldValues[m_ControlFieldIndex][0]) {
case ' ':
// Spaces are just used to fill the buffers and force prior messages
// through the system - we don't need to do anything else.
LOG_TRACE(<< "Received pad of length "
<< fieldValues[m_ControlFieldIndex].length());
return true;
case FINISHED_DATA_CONTROL_MESSAGE_FIELD_VALUE:
this->receivedAllRows();
this->run();
break;
default:
unrecognised = true;
break;
}
if (unrecognised || fieldValues[m_ControlFieldIndex].size() > 1) {
HANDLE_FATAL(<< "Input error: invalid control message value '"
<< fieldValues[m_ControlFieldIndex] << "'. Please report this problem.");
return false;
}
return true;
}
void CDataFrameAnalyzer::captureFieldNames(const TStrVec& fieldNames) {
if (m_DataFrame == nullptr || m_CapturedFieldNames) {
return;
}
TStrVec columnNames{fieldNames.begin() + m_BeginDataFieldValues,
fieldNames.begin() + m_EndDataFieldValues};
if (m_DataFrame->hasColumnNames()) {
this->initializeDataFrameColumnMap(std::move(columnNames));
this->validateCategoricalColumnsMatch();
} else {
m_DataFrame->columnNames(std::move(columnNames));
m_DataFrame->categoricalColumns(m_AnalysisSpecification->categoricalFieldNames());
}
m_CapturedFieldNames = true;
}
// Builds m_DataFrameColumnMap, which maps each existing data frame column
// (in frame order) to the position of the same-named column in the input, or
// FIELD_MISSING if the input lacks it.
//
// Missing fields are not fatal because they may not be available for the new
// set. Extra fields probably indicate a user error, so they are fatal.
void CDataFrameAnalyzer::initializeDataFrameColumnMap(TStrVec columnNames) {
    // Sort the input names, keeping track of each one's original position.
    TPtrdiffVec positions(columnNames.size());
    std::iota(positions.begin(), positions.end(), 0);
    maths::common::COrderings::simultaneousSort(columnNames, positions);

    const auto& frameColumnNames = m_DataFrame->columnNames();

    TStrVec sortedFrameColumnNames{frameColumnNames};
    std::sort(sortedFrameColumnNames.begin(), sortedFrameColumnNames.end());

    TStrVec extraColumnNames;
    std::set_difference(columnNames.begin(), columnNames.end(),
                        sortedFrameColumnNames.begin(), sortedFrameColumnNames.end(),
                        std::back_inserter(extraColumnNames));
    if (extraColumnNames.empty() == false) {
        HANDLE_FATAL(<< "Input error: supplying additional columns "
                     << extraColumnNames << ".");
    }

    // One entry is pushed per frame column, so reserve for that count rather
    // than for the number of input columns.
    m_DataFrameColumnMap = std::make_unique<TPtrdiffVec>();
    m_DataFrameColumnMap->reserve(frameColumnNames.size());
    for (const auto& name : frameColumnNames) {
        auto i = std::lower_bound(columnNames.begin(), columnNames.end(), name);
        if (i == columnNames.end() || *i != name) {
            LOG_WARN(<< "Missing column '" << name << "'.");
            m_DataFrameColumnMap->push_back(FIELD_MISSING);
        } else {
            m_DataFrameColumnMap->push_back(positions[i - columnNames.begin()]);
        }
    }
    LOG_TRACE(<< "mapping = " << *m_DataFrameColumnMap);
}
// Checks that the set of categorical columns already recorded in the data
// frame equals the categorical field names in the specification, ignoring
// order. A mismatch is a fatal input error.
void CDataFrameAnalyzer::validateCategoricalColumnsMatch() const {
    const auto& frameColumnNames = m_DataFrame->columnNames();
    const auto& frameIsCategorical = m_DataFrame->columnIsCategorical();

    TStrVec originalCategoricalColumns;
    originalCategoricalColumns.reserve(frameColumnNames.size());
    for (std::size_t column = 0; column < frameColumnNames.size(); ++column) {
        if (frameIsCategorical[column]) {
            originalCategoricalColumns.push_back(frameColumnNames[column]);
        }
    }
    std::sort(originalCategoricalColumns.begin(), originalCategoricalColumns.end());

    TStrVec categoricalColumns{m_AnalysisSpecification->categoricalFieldNames()};
    std::sort(categoricalColumns.begin(), categoricalColumns.end());

    if ((categoricalColumns == originalCategoricalColumns) == false) {
        HANDLE_FATAL(<< "Input error: mismatch in categorical columns " << categoricalColumns
                     << " doesn't match " << originalCategoricalColumns << ".");
    }
}
// Parses the data fields of one record into the data frame. Passes the column
// map (null if none was built) and, when present, the document hash field.
void CDataFrameAnalyzer::addRowToDataFrame(const TStrVec& fieldValues) {
    if (m_DataFrame != nullptr) {
        auto dataValues = core::make_range(fieldValues, m_BeginDataFieldValues,
                                           m_EndDataFieldValues);
        const std::string* docHash{nullptr};
        if (m_DocHashFieldIndex != FIELD_MISSING) {
            docHash = &fieldValues[m_DocHashFieldIndex];
        }
        m_DataFrame->parseAndWriteRow(dataValues, m_DataFrameColumnMap.get(), docHash);
    }
}
void CDataFrameAnalyzer::writeInferenceModel(const CDataFrameAnalysisRunner& analysis,
core::CBoostJsonConcurrentLineWriter& writer) const {
// Write the resulting model for inference.
auto modelDefinition = analysis.inferenceModelDefinition(
m_DataFrame->columnNames(), m_DataFrame->categoricalColumnValues());
if (modelDefinition != nullptr) {
auto modelDefinitionSizeInfo = modelDefinition->sizeInfo();
json::object sizeInfoObject{writer.makeObject()};
modelDefinitionSizeInfo->addToJsonDocument(sizeInfoObject, writer);
writer.onObjectBegin();
writer.onKey(modelDefinitionSizeInfo->typeString());
writer.write(sizeInfoObject);
writer.onObjectEnd();
modelDefinition->addCompressedToJsonStream(writer);
}
writer.flush();
}
// Writes the model metadata, if any, as {"<type string>": {...}}. The writer
// is always flushed.
void CDataFrameAnalyzer::writeInferenceModelMetadata(const CDataFrameAnalysisRunner& analysis,
                                                     core::CBoostJsonConcurrentLineWriter& writer) const {
    if (auto metadata = analysis.inferenceModelMetadata()) {
        writer.onObjectBegin();
        writer.onKey(metadata->typeString());
        writer.onObjectBegin();
        metadata->write(writer);
        writer.onObjectEnd();
        writer.onObjectEnd();
    }
    writer.flush();
}
// Writes the compressed training data summary, if the runner produced one.
// The writer is always flushed.
void CDataFrameAnalyzer::writeDataSummarization(const CDataFrameAnalysisRunner& analysis,
                                                core::CBoostJsonConcurrentLineWriter& writer) const {
    auto summarization = analysis.dataSummarization();
    if (summarization == nullptr) {
        writer.flush();
        return;
    }
    summarization->addCompressedToJsonStream(writer);
    writer.flush();
}
// Writes one row_results object per row selected by the runner's row mask:
//   {"row_results": {"checksum": <doc hash>,
//                    "results": {"<results field>": <runner output>}}}
//
// Rows are read single-threaded because Java needs them in the order they
// were supplied, so it can join the extra columns onto the original frame.
void CDataFrameAnalyzer::writeResultsOf(const CDataFrameAnalysisRunner& analysis,
                                        core::CBoostJsonConcurrentLineWriter& writer) const {
    std::size_t singleThread{1};
    auto mask = analysis.rowsToWriteMask(*m_DataFrame);
    LOG_TRACE(<< "# rows to write = " << mask.manhattan());

    using TRowItr = core::CDataFrame::TRowItr;
    m_DataFrame->readRows(
        singleThread, 0, m_DataFrame->numberRows(),
        [&](const TRowItr& firstRow, const TRowItr& lastRow) {
            for (auto current = firstRow; current != lastRow; ++current) {
                writer.onObjectBegin();
                writer.onKey(ROW_RESULTS);

                writer.onObjectBegin();
                writer.onKey(CHECKSUM);
                writer.onInt(current->docHash());
                writer.onKey(RESULTS);

                writer.onObjectBegin();
                writer.onKey(m_AnalysisSpecification->resultsField());
                analysis.writeOneRow(*m_DataFrame, *current, writer);
                writer.onObjectEnd();

                writer.onObjectEnd();
                writer.onObjectEnd();
            }
        },
        &mask);
    writer.flush();
}
// Returns the analysis runner, or nullptr if there is none.
//
// The specification may be null (handleRecord and run both check for this),
// so guard against dereferencing it instead of crashing.
const CDataFrameAnalysisRunner* CDataFrameAnalyzer::runner() const {
    return m_AnalysisSpecification != nullptr ? m_AnalysisSpecification->runner() : nullptr;
}
// Field name used for both special trailing columns (document hash and
// control message).
const std::string CDataFrameAnalyzer::CONTROL_MESSAGE_FIELD_NAME = ".";
// Sentinel: the special field layout has not been determined yet.
const std::ptrdiff_t CDataFrameAnalyzer::FIELD_UNSET = -2;
// Sentinel: the input does not contain the field.
const std::ptrdiff_t CDataFrameAnalyzer::FIELD_MISSING = -1;
}
}