lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc (186 lines of code) (raw):

/* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License * 2.0 and the following additional limitation. Functionality enabled by the * files subject to the Elastic License 2.0 may only be used in production when * invoked by an Elasticsearch process with a license key installed that permits * use of machine learning features. You may not use this file except in * compliance with the Elastic License 2.0 and the foregoing additional * limitation. */ #include <api/CDataFrameTrainBoostedTreeRegressionRunner.h> #include <core/CBoostJsonConcurrentLineWriter.h> #include <core/CLogger.h> #include <maths/analytics/CBoostedTree.h> #include <maths/analytics/CBoostedTreeFactory.h> #include <maths/analytics/CBoostedTreeHyperparameters.h> #include <maths/analytics/CBoostedTreeLoss.h> #include <maths/analytics/CDataFrameUtils.h> #include <maths/analytics/CTreeShapFeatureImportance.h> #include <api/CBoostedTreeInferenceModelBuilder.h> #include <api/CDataFrameAnalysisConfigReader.h> #include <api/CDataFrameAnalysisSpecification.h> #include <api/ElasticsearchStateIndex.h> #include <cmath> #include <memory> #include <set> #include <string> namespace ml { namespace api { namespace { // Output const std::string IS_TRAINING_FIELD_NAME{"is_training"}; const std::set<std::string> PREDICTION_FIELD_NAME_BLACKLIST{IS_TRAINING_FIELD_NAME}; } const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeRegressionRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(STRATIFIED_CROSS_VALIDATION, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter( LOSS_FUNCTION, CDataFrameAnalysisConfigReader::E_OptionalParameter, {{MSE, int{TLossFunctionType::E_MseRegression}}, {MSLE, int{TLossFunctionType::E_MsleRegression}}, {PSEUDO_HUBER, int{TLossFunctionType::E_HuberRegression}}}); theReader.addParameter(LOSS_FUNCTION_PARAMETER, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; }()}; return PARAMETER_READER; } CDataFrameTrainBoostedTreeRegressionRunner::TLossFunctionUPtr CDataFrameTrainBoostedTreeRegressionRunner::lossFunction(const CDataFrameAnalysisParameters& parameters) { TLossFunctionType lossFunctionType{ parameters[LOSS_FUNCTION].fallback(TLossFunctionType::E_MseRegression)}; switch (lossFunctionType) { case TLossFunctionType::E_MsleRegression: return std::make_unique<maths::analytics::boosted_tree::CMsle>( parameters[LOSS_FUNCTION_PARAMETER].fallback(1.0)); case TLossFunctionType::E_MseRegression: return std::make_unique<maths::analytics::boosted_tree::CMse>(); case TLossFunctionType::E_HuberRegression: return std::make_unique<maths::analytics::boosted_tree::CPseudoHuber>( parameters[LOSS_FUNCTION_PARAMETER].fallback(1.0)); case TLossFunctionType::E_BinaryClassification: case TLossFunctionType::E_MulticlassClassification: LOG_ERROR(<< "Input error: regression loss type is expected but classification type is provided. Defaulting to MSE instead."); return std::make_unique<maths::analytics::boosted_tree::CMse>(); } return nullptr; } CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegressionRunner( const CDataFrameAnalysisSpecification& spec, const CDataFrameAnalysisParameters& parameters, TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) : CDataFrameTrainBoostedTreeRunner{spec, parameters, lossFunction(parameters), frameAndDirectory} { this->boostedTreeFactory().stratifyRegressionCrossValidation( parameters[STRATIFIED_CROSS_VALIDATION].fallback(true)); const TStrVec& categoricalFieldNames{spec.categoricalFieldNames()}; if (std::find(categoricalFieldNames.begin(), categoricalFieldNames.end(), this->dependentVariableFieldName()) != categoricalFieldNames.end()) { HANDLE_FATAL(<< "Input error: trying to perform regression with categorical target."); } if (PREDICTION_FIELD_NAME_BLACKLIST.count(this->predictionFieldName()) > 0) { HANDLE_FATAL(<< "Input error: " << PREDICTION_FIELD_NAME << " must not be equal to any of " << PREDICTION_FIELD_NAME_BLACKLIST << "."); } } void CDataFrameTrainBoostedTreeRegressionRunner::writeOneRow( const core::CDataFrame&, const TRowRef& row, core::CBoostJsonConcurrentLineWriter& writer) const { const auto& tree = this->boostedTree(); const std::size_t columnHoldingDependentVariable{tree.columnHoldingDependentVariable()}; writer.onObjectBegin(); writer.onKey(this->predictionFieldName()); writer.onDouble(tree.prediction(row)[0]); writer.onKey(IS_TRAINING_FIELD_NAME); writer.onBool(maths::analytics::CDataFrameUtils::isMissing( row[columnHoldingDependentVariable]) == false); auto* featureImportance = tree.shap(); if (featureImportance != nullptr) { m_InferenceModelMetadata.columnNames(featureImportance->columnNames()); featureImportance->shap( row, [&writer, this]( const maths::analytics::CTreeShapFeatureImportance::TSizeVec& indices, const TStrVec& featureNames, const maths::analytics::CTreeShapFeatureImportance::TVectorVec& shap) { writer.onKey(FEATURE_IMPORTANCE_FIELD_NAME); writer.onArrayBegin(); for (auto i : indices) { if (shap[i].norm() != 0.0) { writer.onObjectBegin(); writer.onKey(FEATURE_NAME_FIELD_NAME); writer.onString(featureNames[i]); writer.onKey(IMPORTANCE_FIELD_NAME); writer.onDouble(shap[i](0)); writer.onObjectEnd(); } } writer.onArrayEnd(); for (int i = 0; i < static_cast<int>(shap.size()); ++i) { if (shap[i].lpNorm<1>() != 0) { const_cast<CDataFrameTrainBoostedTreeRegressionRunner*>(this) ->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]); } } }); } writer.onObjectEnd(); } void CDataFrameTrainBoostedTreeRegressionRunner::validate(const core::CDataFrame&, std::size_t) const { } CDataFrameAnalysisRunner::TInferenceModelDefinitionUPtr CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition( const CDataFrameAnalysisRunner::TStrVec& fieldNames, const CDataFrameAnalysisRunner::TStrVecVec& categoryNames) const { CRegressionInferenceModelBuilder builder( fieldNames, this->boostedTree().columnHoldingDependentVariable(), categoryNames); this->accept(builder); return std::make_unique<CInferenceModelDefinition>(builder.build()); } const CInferenceModelMetadata* CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const { auto* featureImportance = this->boostedTree().shap(); if (featureImportance != nullptr) { m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline()); } switch (this->task()) { case api_t::E_Encode: case api_t::E_Predict: break; case api_t::E_Train: case api_t::E_Update: m_InferenceModelMetadata.hyperparameterImportance( this->boostedTree().hyperparameterImportance()); break; } m_InferenceModelMetadata.numTrainRows(this->boostedTree().numberTrainRows()); m_InferenceModelMetadata.lossGap(this->boostedTree().lossGap()); m_InferenceModelMetadata.numDataSummarizationRows(static_cast<std::size_t>( this->boostedTree().dataSummarization().manhattan())); return &m_InferenceModelMetadata; } // clang-format off const std::string CDataFrameTrainBoostedTreeRegressionRunner::STRATIFIED_CROSS_VALIDATION{"stratified_cross_validation"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION{"loss_function"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::LOSS_FUNCTION_PARAMETER{"loss_function_parameter"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::MSE{"mse"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::MSLE{"msle"}; const std::string CDataFrameTrainBoostedTreeRegressionRunner::PSEUDO_HUBER{"huber"}; // clang-format on const std::string& CDataFrameTrainBoostedTreeRegressionRunnerFactory::name() const { return NAME; } CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl( const CDataFrameAnalysisSpecification&, TDataFrameUPtrTemporaryDirectoryPtrPr*) const { HANDLE_FATAL(<< "Input error: regression has a non-optional parameter '" << CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'."); return nullptr; } CDataFrameTrainBoostedTreeRegressionRunnerFactory::TRunnerUPtr CDataFrameTrainBoostedTreeRegressionRunnerFactory::makeImpl( const CDataFrameAnalysisSpecification& spec, const json::value& jsonParameters, TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) const { const CDataFrameAnalysisConfigReader& parameterReader{ CDataFrameTrainBoostedTreeRegressionRunner::parameterReader()}; auto parameters = parameterReader.read(jsonParameters); return std::make_unique<CDataFrameTrainBoostedTreeRegressionRunner>( spec, parameters, frameAndDirectory); } const std::string CDataFrameTrainBoostedTreeRegressionRunnerFactory::NAME{"regression"}; } }