lib/api/CDataFrameAnalysisSpecification.cc (237 lines of code) (raw):

/* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License * 2.0 and the following additional limitation. Functionality enabled by the * files subject to the Elastic License 2.0 may only be used in production when * invoked by an Elasticsearch process with a license key installed that permits * use of machine learning features. You may not use this file except in * compliance with the Elastic License 2.0 and the foregoing additional * limitation. */ #include <api/CDataFrameAnalysisSpecification.h> #include <core/CBoostJsonParser.h> #include <core/CDataFrame.h> #include <core/CLogger.h> #include <core/CStringUtils.h> #include <api/CDataFrameAnalysisConfigReader.h> #include <api/CDataFrameOutliersRunner.h> #include <api/CDataFrameTrainBoostedTreeClassifierRunner.h> #include <api/CDataFrameTrainBoostedTreeRegressionRunner.h> #include <api/CMemoryUsageEstimationResultJsonWriter.h> #include <boost/json.hpp> #include <iterator> #include <memory> namespace ml { namespace api { // These must be consistent with Java names. const std::string CDataFrameAnalysisSpecification::JOB_ID{"job_id"}; const std::string CDataFrameAnalysisSpecification::ROWS{"rows"}; const std::string CDataFrameAnalysisSpecification::COLS{"cols"}; const std::string CDataFrameAnalysisSpecification::MEMORY_LIMIT{"memory_limit"}; const std::string CDataFrameAnalysisSpecification::THREADS{"threads"}; const std::string CDataFrameAnalysisSpecification::TEMPORARY_DIRECTORY{"temp_dir"}; const std::string CDataFrameAnalysisSpecification::RESULTS_FIELD{"results_field"}; const std::string CDataFrameAnalysisSpecification::MISSING_FIELD_VALUE{"missing_field_value"}; const std::string CDataFrameAnalysisSpecification::CATEGORICAL_FIELD_NAMES{"categorical_fields"}; const std::string CDataFrameAnalysisSpecification::DISK_USAGE_ALLOWED{"disk_usage_allowed"}; const std::string CDataFrameAnalysisSpecification::ANALYSIS{"analysis"}; const std::string CDataFrameAnalysisSpecification::NAME{"name"}; const std::string CDataFrameAnalysisSpecification::PARAMETERS{"parameters"}; namespace { using TBoolVec = std::vector<bool>; using TRunnerFactoryUPtrVec = CDataFrameAnalysisSpecification::TRunnerFactoryUPtrVec; TRunnerFactoryUPtrVec analysisFactories() { TRunnerFactoryUPtrVec factories; factories.push_back(std::make_unique<CDataFrameOutliersRunnerFactory>()); factories.push_back( std::make_unique<CDataFrameTrainBoostedTreeRegressionRunnerFactory>()); factories.push_back( std::make_unique<CDataFrameTrainBoostedTreeClassifierRunnerFactory>()); // Add new analysis types here. return factories; } const std::string DEFAULT_RESULT_FIELD("ml"); const bool DEFAULT_DISK_USAGE_ALLOWED(false); const CDataFrameAnalysisConfigReader CONFIG_READER{[] { CDataFrameAnalysisConfigReader theReader; theReader.addParameter(CDataFrameAnalysisSpecification::JOB_ID, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::ROWS, CDataFrameAnalysisConfigReader::E_RequiredParameter); theReader.addParameter(CDataFrameAnalysisSpecification::COLS, CDataFrameAnalysisConfigReader::E_RequiredParameter); theReader.addParameter(CDataFrameAnalysisSpecification::MEMORY_LIMIT, CDataFrameAnalysisConfigReader::E_RequiredParameter); theReader.addParameter(CDataFrameAnalysisSpecification::THREADS, CDataFrameAnalysisConfigReader::E_RequiredParameter); theReader.addParameter(CDataFrameAnalysisSpecification::TEMPORARY_DIRECTORY, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::RESULTS_FIELD, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::MISSING_FIELD_VALUE, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::CATEGORICAL_FIELD_NAMES, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::DISK_USAGE_ALLOWED, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(CDataFrameAnalysisSpecification::ANALYSIS, CDataFrameAnalysisConfigReader::E_RequiredParameter); return theReader; }()}; const CDataFrameAnalysisConfigReader ANALYSIS_READER{[] { CDataFrameAnalysisConfigReader theReader; theReader.addParameter(CDataFrameAnalysisSpecification::NAME, CDataFrameAnalysisConfigReader::E_RequiredParameter); theReader.addParameter(CDataFrameAnalysisSpecification::PARAMETERS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; }()}; } CDataFrameAnalysisSpecification::CDataFrameAnalysisSpecification( const std::string& jsonSpecification, TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory, TPersisterSupplier persisterSupplier, TRestoreSearcherSupplier restoreSearcherSupplier) : CDataFrameAnalysisSpecification{analysisFactories(), jsonSpecification, frameAndDirectory, std::move(persisterSupplier), std::move(restoreSearcherSupplier)} { } CDataFrameAnalysisSpecification::CDataFrameAnalysisSpecification( TRunnerFactoryUPtrVec runnerFactories, const std::string& jsonSpecification, TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory, TPersisterSupplier persisterSupplier, TRestoreSearcherSupplier restoreSearcherSupplier) : m_RunnerFactories{std::move(runnerFactories)}, m_PersisterSupplier{std::move(persisterSupplier)}, m_RestoreSearcherSupplier{std::move(restoreSearcherSupplier)} { json::value specification; bool ok = core::CBoostJsonParser::parse(jsonSpecification, specification); if (ok == false) { HANDLE_FATAL(<< "Input error: failed to parse analysis specification '" << jsonSpecification << "Please report this problem."); } else { LOG_TRACE(<< "specification: " << jsonSpecification); auto parameters = CONFIG_READER.read(specification); for (const auto& name : {ROWS, COLS, MEMORY_LIMIT, THREADS}) { if (parameters[name].as<std::size_t>() == 0) { HANDLE_FATAL(<< "Input error: '" << name << "' must be non-zero"); } } m_NumberRows = parameters[ROWS].as<std::size_t>(); m_NumberColumns = parameters[COLS].as<std::size_t>(); m_MemoryLimit = parameters[MEMORY_LIMIT].as<std::size_t>(); m_NumberThreads = parameters[THREADS].as<std::size_t>(); m_TemporaryDirectory = parameters[TEMPORARY_DIRECTORY].fallback(std::string{}); m_JobId = parameters[JOB_ID].fallback(std::string{}); m_ResultsField = parameters[RESULTS_FIELD].fallback(DEFAULT_RESULT_FIELD); m_MissingFieldValue = parameters[MISSING_FIELD_VALUE].fallback( core::CDataFrame::DEFAULT_MISSING_STRING); m_CategoricalFieldNames = parameters[CATEGORICAL_FIELD_NAMES].fallback(TStrVec{}); m_DiskUsageAllowed = parameters[DISK_USAGE_ALLOWED].fallback(DEFAULT_DISK_USAGE_ALLOWED); double missing; if (m_MissingFieldValue != core::CDataFrame::DEFAULT_MISSING_STRING && core::CStringUtils::stringToTypeSilent(m_MissingFieldValue, missing)) { HANDLE_FATAL(<< "Input error: you can't use a number (" << missing << ") to denote a missing field value."); } if (m_DiskUsageAllowed && m_TemporaryDirectory.empty()) { HANDLE_FATAL(<< "Input error: temporary directory path should be explicitly set if disk" " usage is allowed! Please report this problem."); } const auto* jsonAnalysis = parameters[ANALYSIS].jsonObject(); if (jsonAnalysis != nullptr) { this->initializeRunner(*jsonAnalysis, frameAndDirectory); } } } std::size_t CDataFrameAnalysisSpecification::numberRows() const { return m_NumberRows; } std::size_t CDataFrameAnalysisSpecification::numberColumns() const { return m_NumberColumns; } std::size_t CDataFrameAnalysisSpecification::numberExtraColumns() const { return m_Runner != nullptr ? m_Runner->numberExtraColumns() : 0; } std::size_t CDataFrameAnalysisSpecification::memoryLimit() const { return m_MemoryLimit; } std::size_t CDataFrameAnalysisSpecification::numberThreads() const { return m_NumberThreads; } const std::string& CDataFrameAnalysisSpecification::resultsField() const { return m_ResultsField; } const std::string& CDataFrameAnalysisSpecification::jobId() const { return m_JobId; } const std::string& CDataFrameAnalysisSpecification::analysisName() const { return m_AnalysisName; } const std::string& CDataFrameAnalysisSpecification::missingFieldValue() const { return m_MissingFieldValue; } const CDataFrameAnalysisSpecification::TStrVec& CDataFrameAnalysisSpecification::categoricalFieldNames() const { return m_CategoricalFieldNames; } bool CDataFrameAnalysisSpecification::diskUsageAllowed() const { return m_DiskUsageAllowed; } const std::string& CDataFrameAnalysisSpecification::temporaryDirectory() const { return m_TemporaryDirectory; } bool CDataFrameAnalysisSpecification::validate(const core::CDataFrame& frame) const { // The main condition to care about is if the analysis might use more memory // than was budgeted for. There are circumstances in which rows are excluded // after the search filter is applied so this can't trap the case that the row // counts are not equal. if (frame.numberRows() > this->numberRows()) { HANDLE_FATAL(<< "Input error: expected no more than '" << this->numberRows() << "' rows but got '" << frame.numberRows() << "' rows" << ". Please report this problem."); return false; } // As with rows, we only care if the analysis might use more memory than was // budgeted for. if (frame.numberColumns() > this->numberColumns()) { HANDLE_FATAL(<< "Input error: expected '" << this->numberColumns() << "' columns but got '" << frame.numberRows() << "' columns" << ". Please report this problem."); return false; } if (frame.numberRows() == 0) { HANDLE_FATAL(<< "Input error: no data sent."); return false; } return m_Runner == nullptr || m_Runner->validate(frame); } CDataFrameAnalysisRunner* CDataFrameAnalysisSpecification::runner() { return m_Runner.get(); } void CDataFrameAnalysisSpecification::estimateMemoryUsage(CMemoryUsageEstimationResultJsonWriter& writer) const { if (m_Runner == nullptr) { HANDLE_FATAL(<< "Internal error: no runner available so can't estimate memory." << " Please report this problem."); return; } m_Runner->estimateMemoryUsage(writer); } void CDataFrameAnalysisSpecification::initializeRunner(const json::value& jsonAnalysis, TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) { // We pass of the interpretation of the parameters object to the appropriate // analysis runner. LOG_TRACE(<< "jsonAnalysis: " << jsonAnalysis); auto analysis = ANALYSIS_READER.read(jsonAnalysis); m_AnalysisName = analysis[NAME].as<std::string>(); LOG_TRACE(<< "Parsed analysis with name '" << m_AnalysisName << "'"); for (const auto& factory : m_RunnerFactories) { if (m_AnalysisName == factory->name()) { const auto* parameters = analysis[PARAMETERS].jsonObject(); m_Runner = parameters != nullptr ? factory->make(*this, *parameters, frameAndDirectory) : factory->make(*this, frameAndDirectory); return; } } HANDLE_FATAL(<< "Input error: unexpected analysis name '" << m_AnalysisName << "'. Please report this problem."); } CDataFrameAnalysisSpecification::TDataAdderUPtr CDataFrameAnalysisSpecification::persister() const { return m_PersisterSupplier(); } CDataFrameAnalysisSpecification::TDataSearcherUPtr CDataFrameAnalysisSpecification::restoreSearcher() const { return m_RestoreSearcherSupplier(); } CDataFrameAnalysisSpecification::TDataAdderUPtr CDataFrameAnalysisSpecification::noopPersisterSupplier() { return nullptr; } CDataFrameAnalysisSpecification::TDataSearcherUPtr CDataFrameAnalysisSpecification::noopRestoreSearcherSupplier() { return nullptr; } } }