lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc (350 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <api/CDataFrameTrainBoostedTreeClassifierRunner.h>

#include <core/CBoostJsonConcurrentLineWriter.h>
#include <core/CDataFrame.h>
#include <core/CLogger.h>
#include <core/CMemoryDef.h>

#include <maths/analytics/CBoostedTree.h>
#include <maths/analytics/CBoostedTreeFactory.h>
#include <maths/analytics/CBoostedTreeHyperparameters.h>
#include <maths/analytics/CBoostedTreeLoss.h>
#include <maths/analytics/CDataFrameUtils.h>
#include <maths/analytics/CTreeShapFeatureImportance.h>

#include <maths/common/CLinearAlgebraEigen.h>

#include <api/CBoostedTreeInferenceModelBuilder.h>
#include <api/CDataFrameAnalysisConfigReader.h>
#include <api/CDataFrameAnalysisSpecification.h>
#include <api/ElasticsearchStateIndex.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

namespace ml {
namespace api {
namespace {
using TBoolVec = std::vector<bool>;
using TDoubleVec = std::vector<double>;
using TSizeVec = std::vector<std::size_t>;
using TStrSet = std::set<std::string>;

// Output
const std::string IS_TRAINING_FIELD_NAME{"is_training"};
const std::string PREDICTION_PROBABILITY_FIELD_NAME{"prediction_probability"};
const std::string PREDICTION_SCORE_FIELD_NAME{"prediction_score"};
const std::string TOP_CLASSES_FIELD_NAME{"top_classes"};
const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
const std::string CLASS_SCORE_FIELD_NAME{"class_score"};
// Names written at the top level of each result row by writeOneRow. The
// constructor rejects a prediction field name equal to any of these.
const TStrSet PREDICTION_FIELD_NAME_BLACKLIST{
    IS_TRAINING_FIELD_NAME, PREDICTION_PROBABILITY_FIELD_NAME,
    PREDICTION_SCORE_FIELD_NAME, TOP_CLASSES_FIELD_NAME};
}

// Returns the parameter reader for classification. It is built once, on first
// use, by copying the base boosted tree runner's reader and registering the
// classification specific parameters: num_classes (required) and
// num_top_classes, prediction_field_type, class_assignment_objective and
// classification_weights (all optional).
const CDataFrameAnalysisConfigReader&
CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
    static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
        auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
        theReader.addParameter(NUM_CLASSES, CDataFrameAnalysisConfigReader::E_RequiredParameter);
        theReader.addParameter(NUM_TOP_CLASSES,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        const std::string typeString{"string"};
        const std::string typeInt{"int"};
        const std::string typeBool{"bool"};
        theReader.addParameter(PREDICTION_FIELD_TYPE,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter,
                               {{typeString, int{E_PredictionFieldTypeString}},
                                {typeInt, int{E_PredictionFieldTypeInt}},
                                {typeBool, int{E_PredictionFieldTypeBool}}});
        int accuracy{maths::analytics::CDataFramePredictiveModel::E_Accuracy};
        int recall{maths::analytics::CDataFramePredictiveModel::E_MinimumRecall};
        int custom{maths::analytics::CDataFramePredictiveModel::E_Custom};
        theReader.addParameter(
            CLASS_ASSIGNMENT_OBJECTIVE, CDataFrameAnalysisConfigReader::E_OptionalParameter,
            {{CLASS_ASSIGNMENT_OBJECTIVE_VALUES[accuracy], accuracy},
             {CLASS_ASSIGNMENT_OBJECTIVE_VALUES[recall], recall},
             {CLASS_ASSIGNMENT_OBJECTIVE_VALUES[custom], custom}});
        theReader.addParameter(CLASSIFICATION_WEIGHTS,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        return theReader;
    }()};
    return PARAMETER_READER;
}

// Reads the classification parameters and configures the boosted tree factory.
//
// Defaults used when a parameter is absent: the class assignment objective is
// E_MinimumRecall, the number of top classes is 0, the prediction field type
// is string and the classification weights are empty.
//
// A fatal input error is reported via HANDLE_FATAL if:
//   - the dependent variable is not one of the specification's categorical
//     fields,
//   - the prediction field name is one of the reserved output field names,
//   - classification weights are supplied but the class assignment objective
//     is not custom,
//   - classification weights are supplied but their count differs from
//     num_classes.
CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifierRunner(
    const CDataFrameAnalysisSpecification& spec,
    const CDataFrameAnalysisParameters& parameters,
    TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory)
    : CDataFrameTrainBoostedTreeRunner{
          spec, parameters, loss(parameters[NUM_CLASSES].as<std::size_t>()), frameAndDirectory} {

    m_NumClasses = parameters[NUM_CLASSES].as<std::size_t>();
    auto classAssignmentObjective = parameters[CLASS_ASSIGNMENT_OBJECTIVE].fallback(
        maths::analytics::CBoostedTree::E_MinimumRecall);
    m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::ptrdiff_t{0});
    m_PredictionFieldType =
        parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
    this->boostedTreeFactory().classAssignmentObjective(classAssignmentObjective);
    auto classificationWeights = parameters[CLASSIFICATION_WEIGHTS].fallback(
        CLASSIFICATION_WEIGHTS_CLASS, CLASSIFICATION_WEIGHTS_WEIGHT,
        std::vector<std::pair<std::string, double>>{});
    if (classificationWeights.empty() == false) {
        this->boostedTreeFactory().classificationWeights(classificationWeights);
    }

    const TStrVec& categoricalFieldNames{spec.categoricalFieldNames()};
    if (std::find(categoricalFieldNames.begin(), categoricalFieldNames.end(),
                  this->dependentVariableFieldName()) == categoricalFieldNames.end()) {
        HANDLE_FATAL(<< "Input error: trying to perform classification with numeric target.");
    }
    if (PREDICTION_FIELD_NAME_BLACKLIST.count(this->predictionFieldName()) > 0) {
        HANDLE_FATAL(<< "Input error: " << PREDICTION_FIELD_NAME << " must not be equal to any of "
                     << PREDICTION_FIELD_NAME_BLACKLIST << ".");
    }
    if (classificationWeights.empty() == false &&
        classAssignmentObjective != maths::analytics::CBoostedTree::E_Custom) {
        HANDLE_FATAL(<< "Input error: expected "
                     << CLASS_ASSIGNMENT_OBJECTIVE_VALUES[maths::analytics::CDataFramePredictiveModel::E_Custom]
                     << " for " << CLASS_ASSIGNMENT_OBJECTIVE << " if supplying "
                     << CLASSIFICATION_WEIGHTS << " but got '"
                     << CLASS_ASSIGNMENT_OBJECTIVE_VALUES[classAssignmentObjective]
                     << "'.");
    }
    if (classificationWeights.empty() == false &&
        classificationWeights.size() != m_NumClasses) {
        HANDLE_FATAL(<< "Input error: expected " << m_NumClasses << " " << CLASSIFICATION_WEIGHTS
                     << " but got " << classificationWeights.size() << ".");
    }
}

// Writes the results for one row using this runner's boosted tree: class
// probabilities come from tree.prediction, class scores from
// tree.adjustedPrediction and feature importance from tree.shap().
void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
    const core::CDataFrame& frame,
    const TRowRef& row,
    core::CBoostJsonConcurrentLineWriter& writer) const {
    const auto& tree = this->boostedTree();
    this->writeOneRow(
        frame, tree.columnHoldingDependentVariable(),
        [&](const TRowRef& row_) { return tree.prediction(row_); },
        [&](const TRowRef& row_) { return tree.adjustedPrediction(row_); }, row,
        writer, tree.shap());
}

// Writes one JSON object holding the results for a row.
//
// The object contains, in order:
//   - the prediction field: the class with the highest score,
//   - prediction_probability and prediction_score for that class,
//   - is_training: true if the row's dependent variable value is not missing,
//   - top_classes (only if num_top_classes is non-zero): the classes sorted
//     by descending score, each with its probability and score,
//   - feature importance (only if featureImportance is not null): one entry
//     per feature whose SHAP vector has non-zero norm.
//
// When featureImportance is supplied the inference model metadata is updated
// as a side effect.
void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
    const core::CDataFrame& frame,
    std::size_t columnHoldingDependentVariable,
    const TReadPredictionFunc& readClassProbabilities,
    const TReadClassScoresFunc& readClassScores,
    const TRowRef& row,
    core::CBoostJsonConcurrentLineWriter& writer,
    maths::analytics::CTreeShapFeatureImportance* featureImportance) const {

    auto probabilities = readClassProbabilities(row);
    auto scores = readClassScores(row);

    double actualClassId{row[columnHoldingDependentVariable]};
    std::size_t predictedClassId(std::max_element(scores.begin(), scores.end()) -
                                 scores.begin());

    const TStrVec& classValues{frame.categoricalColumnValues()[columnHoldingDependentVariable]};
    writer.onObjectBegin();
    writer.onKey(this->predictionFieldName());
    writePredictedCategoryValue(classValues[predictedClassId], writer);
    writer.onKey(PREDICTION_PROBABILITY_FIELD_NAME);
    writer.onDouble(probabilities[predictedClassId]);
    writer.onKey(PREDICTION_SCORE_FIELD_NAME);
    writer.onDouble(scores[predictedClassId]);
    writer.onKey(IS_TRAINING_FIELD_NAME);
    writer.onBool(maths::analytics::CDataFrameUtils::isMissing(actualClassId) == false);

    if (m_NumTopClasses != 0) {
        TSizeVec classIds(scores.size());
        std::iota(classIds.begin(), classIds.end(), 0);
        std::sort(classIds.begin(), classIds.end(),
                  [&scores](std::size_t lhs, std::size_t rhs) {
                      return scores[lhs] > scores[rhs];
                  });
        // -1 is a special value meaning "output all the classes"
        classIds.resize(m_NumTopClasses == -1
                            ? classIds.size()
                            : std::min(classIds.size(),
                                       static_cast<std::size_t>(m_NumTopClasses)));
        writer.onKey(TOP_CLASSES_FIELD_NAME);
        writer.onArrayBegin();
        for (std::size_t i : classIds) {
            writer.onObjectBegin();
            writer.onKey(CLASS_NAME_FIELD_NAME);
            writePredictedCategoryValue(classValues[i], writer);
            writer.onKey(CLASS_PROBABILITY_FIELD_NAME);
            writer.onDouble(probabilities[i]);
            writer.onKey(CLASS_SCORE_FIELD_NAME);
            writer.onDouble(scores[i]);
            writer.onObjectEnd();
        }
        writer.onArrayEnd();
    }

    if (featureImportance != nullptr) {
        int numberClasses{static_cast<int>(classValues.size())};
        m_InferenceModelMetadata.columnNames(featureImportance->columnNames());
        m_InferenceModelMetadata.classValues(classValues);
        m_InferenceModelMetadata.predictionFieldTypeResolverWriter(
            [this](const std::string& categoryValue,
                   core::CBoostJsonConcurrentLineWriter& writer_) {
                this->writePredictedCategoryValue(categoryValue, writer_);
            });
        featureImportance->shap(
            row, [&](const maths::analytics::CTreeShapFeatureImportance::TSizeVec& indices,
                     const TStrVec& featureNames,
                     const maths::analytics::CTreeShapFeatureImportance::TVectorVec& shap) {
                writer.onKey(FEATURE_IMPORTANCE_FIELD_NAME);
                writer.onArrayBegin();
                for (auto i : indices) {
                    if (shap[i].norm() != 0.0) {
                        writer.onObjectBegin();
                        writer.onKey(FEATURE_NAME_FIELD_NAME);
                        writer.onString(featureNames[i]);
                        writer.onKey(CLASSES_FIELD_NAME);
                        writer.onArrayBegin();
                        if (shap[i].size() == 1) {
                            // output feature importance for individual classes in binary case:
                            // the single SHAP value is written as is for class index 1 and
                            // negated for every other class index.
                            for (int j = 0; j < numberClasses; ++j) {
                                writer.onObjectBegin();
                                writer.onKey(CLASS_NAME_FIELD_NAME);
                                this->writePredictedCategoryValue(classValues[j], writer);
                                writer.onKey(IMPORTANCE_FIELD_NAME);
                                if (j == 1) {
                                    writer.onDouble(shap[i](0));
                                } else {
                                    writer.onDouble(-shap[i](0));
                                }
                                writer.onObjectEnd();
                            }
                        } else {
                            // output feature importance for individual classes in multiclass case
                            for (int j = 0; j < shap[i].size() && j < numberClasses; ++j) {
                                writer.onObjectBegin();
                                writer.onKey(CLASS_NAME_FIELD_NAME);
                                this->writePredictedCategoryValue(classValues[j], writer);
                                writer.onKey(IMPORTANCE_FIELD_NAME);
                                writer.onDouble(shap[i](j));
                                writer.onObjectEnd();
                            }
                        }
                        writer.onArrayEnd();
                        writer.onObjectEnd();
                    }
                }
                writer.onArrayEnd();

                // Accumulate every non-zero SHAP vector into the model metadata.
                // The metadata member is updated directly, exactly as it is at
                // the start of this branch, so no cast of this is required.
                for (std::size_t i = 0; i < shap.size(); ++i) {
                    if (shap[i].lpNorm<1>() != 0) {
                        m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
                    }
                }
            });
    }
    writer.onObjectEnd();
}

// Writes a category value using the configured prediction field type.
//
// For the int and bool types the category string is parsed as a double: on
// success it is written as an int64 (via static_cast) or as a bool (true if
// non-zero) respectively; if parsing fails the string itself is written.
void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
    const std::string& categoryValue,
    core::CBoostJsonConcurrentLineWriter& writer) const {

    double doubleValue{0.0};
    switch (m_PredictionFieldType) {
    case E_PredictionFieldTypeString:
        writer.onString(categoryValue);
        break;
    case E_PredictionFieldTypeInt:
        if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
            writer.onInt64(static_cast<std::int64_t>(doubleValue));
        } else {
            writer.onString(categoryValue);
        }
        break;
    case E_PredictionFieldTypeBool:
        if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
            writer.onBool(doubleValue != 0.0);
        } else {
            writer.onString(categoryValue);
        }
        break;
    }
}

// Creates the loss function: binomial logistic loss for exactly two classes
// and multinomial logistic loss over numberClasses otherwise.
CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::loss(std::size_t numberClasses) {
    using namespace maths::analytics::boosted_tree;
    return numberClasses == 2
               ? TLossFunctionUPtr{std::make_unique<CBinomialLogisticLoss>()}
               : TLossFunctionUPtr{std::make_unique<CMultinomialLogisticLoss>(numberClasses)};
}

// Checks the number of categories of the dependent variable column and
// reports a fatal input error if it is less than two, greater than
// MAX_NUMBER_CLASSES or different from the configured num_classes.
void CDataFrameTrainBoostedTreeClassifierRunner::validate(const core::CDataFrame& frame,
                                                          std::size_t dependentVariableColumn) const {
    std::size_t categoryCount{
        frame.categoricalColumnValues()[dependentVariableColumn].size()};
    if (categoryCount < 2) {
        HANDLE_FATAL(<< "Input error: can't run classification unless there are at least "
                     << "two classes. Trying to predict '"
                     << frame.columnNames()[dependentVariableColumn] << "' which has '"
                     << categoryCount << "' categories in the training data. "
                     << "The number of rows read is '" << frame.numberRows() << "'.");
    } else if (categoryCount > MAX_NUMBER_CLASSES) {
        HANDLE_FATAL(<< "Input error: the maximum number of classes supported is "
                     << MAX_NUMBER_CLASSES << ". Trying to predict '"
                     << frame.columnNames()[dependentVariableColumn] << "' which has '"
                     << categoryCount << "' categories in the training data. "
                     << "The number of rows read is '" << frame.numberRows() << "'.");
    } else if (categoryCount != m_NumClasses) {
        HANDLE_FATAL(<< "Input error: " << m_NumClasses << " provided for " << NUM_CLASSES
                     << " but there are " << categoryCount << " in the data: "
                     << frame.categoricalColumnValues()[dependentVariableColumn] << ".");
    }
}

// Builds the inference model definition by passing a classification model
// builder to this runner's accept method.
CDataFrameAnalysisRunner::TInferenceModelDefinitionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
    const CDataFrameAnalysisRunner::TStrVec& fieldNames,
    const CDataFrameAnalysisRunner::TStrVecVec& categoryNames) const {
    CClassificationInferenceModelBuilder builder(
        fieldNames, this->boostedTree().columnHoldingDependentVariable(), categoryNames);
    this->accept(builder);

    return std::make_unique<CInferenceModelDefinition>(builder.build());
}

// Refreshes the inference model metadata from the boosted tree and returns a
// pointer to it. The feature importance baseline is set only if the tree
// provides SHAP feature importance, and hyperparameter importance only for
// the train and update tasks.
const CInferenceModelMetadata*
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
    auto* featureImportance = this->boostedTree().shap();
    if (featureImportance != nullptr) {
        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
    }
    switch (this->task()) {
    case api_t::E_Encode:
    case api_t::E_Predict:
        break;
    case api_t::E_Train:
    case api_t::E_Update:
        m_InferenceModelMetadata.hyperparameterImportance(
            this->boostedTree().hyperparameterImportance());
        break;
    }
    m_InferenceModelMetadata.numTrainRows(this->boostedTree().numberTrainRows());
    m_InferenceModelMetadata.lossGap(this->boostedTree().lossGap());
    m_InferenceModelMetadata.numDataSummarizationRows(static_cast<std::size_t>(
        this->boostedTree().dataSummarization().manhattan()));
    m_InferenceModelMetadata.trainedModelMemoryUsage(
        core::memory::dynamicSize(this->boostedTree().trainedModel()));
    return &m_InferenceModelMetadata;
}

// clang-format off
// The MAX_NUMBER_CLASSES must match the value used in the Java code. See the
// MAX_DEPENDENT_VARIABLE_CARDINALITY in the x-pack classification code.
const std::size_t CDataFrameTrainBoostedTreeClassifierRunner::MAX_NUMBER_CLASSES{100};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::NUM_CLASSES{"num_classes"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::NUM_TOP_CLASSES{"num_top_classes"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::PREDICTION_FIELD_TYPE{"prediction_field_type"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASS_ASSIGNMENT_OBJECTIVE{"class_assignment_objective"};
const CDataFrameTrainBoostedTreeClassifierRunner::TStrVec
CDataFrameTrainBoostedTreeClassifierRunner::CLASS_ASSIGNMENT_OBJECTIVE_VALUES{
    "maximize_accuracy", "maximize_minimum_recall", "custom"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSIFICATION_WEIGHTS{"classification_weights"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSIFICATION_WEIGHTS_CLASS{"class"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSIFICATION_WEIGHTS_WEIGHT{"weight"};
// clang-format on

// Returns the analysis name this factory is registered under.
const std::string& CDataFrameTrainBoostedTreeClassifierRunnerFactory::name() const {
    return NAME;
}

// Overload used when no parameters are supplied: classification cannot be
// configured without a dependent variable, so this reports a fatal input
// error and returns a null runner.
CDataFrameTrainBoostedTreeClassifierRunnerFactory::TRunnerUPtr
CDataFrameTrainBoostedTreeClassifierRunnerFactory::makeImpl(
    const CDataFrameAnalysisSpecification&,
    TDataFrameUPtrTemporaryDirectoryPtrPr*) const {
    HANDLE_FATAL(<< "Input error: classification has a non-optional parameter '"
                 << CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME << "'.");
    return nullptr;
}

// Overload used when JSON parameters are supplied: parses them with the
// classification parameter reader and constructs the classifier runner from
// the specification, the parsed parameters and the frame/directory pair.
CDataFrameTrainBoostedTreeClassifierRunnerFactory::TRunnerUPtr
CDataFrameTrainBoostedTreeClassifierRunnerFactory::makeImpl(
    const CDataFrameAnalysisSpecification& spec,
    const json::value& jsonParameters,
    TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) const {
    using TRunner = CDataFrameTrainBoostedTreeClassifierRunner;
    auto parsedParameters = TRunner::parameterReader().read(jsonParameters);
    return std::make_unique<TRunner>(spec, parsedParameters, frameAndDirectory);
}

// Name under which the classification analysis is registered.
const std::string CDataFrameTrainBoostedTreeClassifierRunnerFactory::NAME{"classification"};

// Result field names used when writing per class values.
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASSES_FIELD_NAME{"classes"};
const std::string CDataFrameTrainBoostedTreeClassifierRunner::CLASS_NAME_FIELD_NAME{"class_name"};
}
}