lib/api/CDataFrameTrainBoostedTreeRunner.cc (589 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <api/CDataFrameTrainBoostedTreeRunner.h>

#include <core/CBoostJsonConcurrentLineWriter.h>
#include <core/CDataFrame.h>
#include <core/CJsonStatePersistInserter.h>
#include <core/CLogger.h>
#include <core/CPackedBitVector.h>
#include <core/CProgramCounters.h>
#include <core/CStateDecompressor.h>
#include <core/CStopWatch.h>
#include <core/Constants.h>

#include <maths/analytics/CBoostedTree.h>
#include <maths/analytics/CBoostedTreeFactory.h>
#include <maths/analytics/CBoostedTreeLoss.h>
#include <maths/analytics/CDataFrameUtils.h>

#include <api/CBoostedTreeInferenceModelBuilder.h>
#include <api/CDataFrameAnalysisConfigReader.h>
#include <api/CDataFrameAnalysisSpecification.h>
#include <api/CDataSummarizationJsonWriter.h>
#include <api/CInferenceModelDefinition.h>
#include <api/CRetrainableModelJsonReader.h>
#include <api/ElasticsearchStateIndex.h>

#include <boost/json.hpp>

#include <limits>

namespace json = boost::json;

namespace ml {
namespace api {
namespace {
// Sentinel meaning MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER was not supplied
// by the caller; the factory default is used in that case.
const std::size_t NUMBER_ROUNDS_PER_HYPERPARAMETER_IS_UNSET{
    std::numeric_limits<std::size_t>::max()};
}

//! Lazily build (thread-safe, first call only) and return the reader which
//! declares every configuration parameter this runner understands.
const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeRunner::parameterReader() {
    static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
        CDataFrameAnalysisConfigReader theReader;
        theReader.addParameter(RANDOM_NUMBER_GENERATOR_SEED,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(DEPENDENT_VARIABLE_NAME,
                               CDataFrameAnalysisConfigReader::E_RequiredParameter);
        theReader.addParameter(PREDICTION_FIELD_NAME,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(DOWNSAMPLE_ROWS_PER_FEATURE,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(DOWNSAMPLE_FACTOR,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(ALPHA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(LAMBDA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(GAMMA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(ETA, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(ETA_GROWTH_RATE_PER_TREE,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(RETRAINED_TREE_ETA,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(SOFT_TREE_DEPTH_LIMIT,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(SOFT_TREE_DEPTH_TOLERANCE,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(MAX_TREES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(MAX_DEPLOYED_MODEL_SIZE,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(FEATURE_BAG_FRACTION,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(PREDICTION_CHANGE_COST,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(TREE_TOPOLOGY_CHANGE_PENALTY,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(NUM_HOLDOUT_ROWS,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(NUM_FOLDS, CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(TRAIN_FRACTION_PER_FOLD,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(STOP_CROSS_VALIDATION_EARLY,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(BAYESIAN_OPTIMISATION_RESTARTS,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(NUM_TOP_FEATURE_IMPORTANCE_VALUES,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(TRAINING_PERCENT_FIELD_NAME,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(FEATURE_PROCESSORS,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(EARLY_STOPPING_ENABLED,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(FORCE_ACCEPT_INCREMENTAL_TRAINING,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(DISABLE_HYPERPARAMETER_SCALING,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(DATA_SUMMARIZATION_FRACTION,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        // TASK maps its string values onto the task enumeration.
        theReader.addParameter(
            TASK, CDataFrameAnalysisConfigReader::E_OptionalParameter,
            {{TASK_ENCODE, int{api_t::EDataFrameTrainBoostedTreeTask::E_Encode}},
             {TASK_TRAIN, int{api_t::EDataFrameTrainBoostedTreeTask::E_Train}},
             {TASK_UPDATE, int{api_t::EDataFrameTrainBoostedTreeTask::E_Update}},
             {TASK_PREDICT, int{api_t::EDataFrameTrainBoostedTreeTask::E_Predict}}});
        theReader.addParameter(PREVIOUS_TRAIN_LOSS_GAP,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(PREVIOUS_TRAIN_NUM_ROWS,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(MAX_NUM_NEW_TREES,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        theReader.addParameter(ROW_WEIGHT_COLUMN,
                               CDataFrameAnalysisConfigReader::E_OptionalParameter);
        return theReader;
    }()};
    return PARAMETER_READER;
}

//! Construct the runner: read and validate all training parameters and
//! configure the boosted tree factory accordingly.
//!
//! \param spec The analysis specification (job id, memory limit, row/column
//! counts, categorical fields, ...).
//! \param parameters The parsed analysis parameters.
//! \param loss The loss function to optimise; must not be null.
//! \param frameAndDirectory Receives the data frame (and its temporary
//! directory) when one must be created for model restore; may be null when
//! only estimating memory usage.
CDataFrameTrainBoostedTreeRunner::CDataFrameTrainBoostedTreeRunner(
    const CDataFrameAnalysisSpecification& spec,
    const CDataFrameAnalysisParameters& parameters,
    TLossFunctionUPtr loss,
    TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory)
    : CDataFrameAnalysisRunner{spec},
      // BUG FIX: member initializers run before the constructor body, so the
      // original code dereferenced loss here before the null check below could
      // ever run. Guard the dereferences so a null loss reaches HANDLE_FATAL
      // instead of being undefined behaviour.
      m_DimensionPrediction{loss != nullptr ? loss->dimensionPrediction() : 0},
      m_DimensionGradient{loss != nullptr ? loss->dimensionGradient() : 0},
      m_Instrumentation{spec.jobId(), spec.memoryLimit()} {

    if (loss == nullptr) {
        HANDLE_FATAL(<< "Internal error: must provide a loss function for training."
                     << " Please report this problem");
        return;
    }

    using TDoubleVec = std::vector<double>;

    m_DependentVariableFieldName = parameters[DEPENDENT_VARIABLE_NAME].as<std::string>();
    m_PredictionFieldName = parameters[PREDICTION_FIELD_NAME].fallback(
        m_DependentVariableFieldName + "_prediction");
    m_TrainingPercent = parameters[TRAINING_PERCENT_FIELD_NAME].fallback(100.0) / 100.0;
    m_Task = parameters[TASK].fallback(api_t::E_Train);
    m_TrainedModelMemoryUsage =
        parameters[TRAINED_MODEL_MEMORY_USAGE].fallback(std::size_t{0});

    // Training parameters.
    auto seed = parameters[RANDOM_NUMBER_GENERATOR_SEED].fallback(std::ptrdiff_t{0});
    auto numberHoldoutRows = parameters[NUM_HOLDOUT_ROWS].fallback(std::size_t{0});
    auto numberFolds = parameters[NUM_FOLDS].fallback(std::size_t{0});
    auto trainFractionPerFold = parameters[TRAIN_FRACTION_PER_FOLD].fallback(-1.0);
    auto downsampleRowsPerFeature =
        parameters[DOWNSAMPLE_ROWS_PER_FEATURE].fallback(std::size_t{0});
    auto numberRoundsPerHyperparameter =
        parameters[MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER].fallback(
            NUMBER_ROUNDS_PER_HYPERPARAMETER_IS_UNSET);
    auto stopHyperparameterOptimizationEarly =
        parameters[EARLY_STOPPING_ENABLED].fallback(true);
    auto bayesianOptimisationRestarts =
        parameters[BAYESIAN_OPTIMISATION_RESTARTS].fallback(std::size_t{0});
    auto stopCrossValidationEarly = parameters[STOP_CROSS_VALIDATION_EARLY].fallback(true);
    auto numberTopShapValues =
        parameters[NUM_TOP_FEATURE_IMPORTANCE_VALUES].fallback(std::size_t{0});
    auto rowWeightColumnName = parameters[ROW_WEIGHT_COLUMN].fallback(std::string{});
    auto maximumDeployedSize = parameters[MAX_DEPLOYED_MODEL_SIZE].fallback(
        core::constants::BYTES_IN_GIGABYTES);

    // Hyperparameters. Empty vectors mean "let the factory choose".
    auto maxTrees = parameters[MAX_TREES].fallback(std::size_t{0});
    auto alpha = parameters[ALPHA].fallback(TDoubleVec{});
    auto lambda = parameters[LAMBDA].fallback(TDoubleVec{});
    auto gamma = parameters[GAMMA].fallback(TDoubleVec{});
    auto eta = parameters[ETA].fallback(TDoubleVec{});
    auto etaGrowthRatePerTree = parameters[ETA_GROWTH_RATE_PER_TREE].fallback(TDoubleVec{});
    auto retrainedTreeEta = parameters[RETRAINED_TREE_ETA].fallback(TDoubleVec{});
    auto softTreeDepthLimit = parameters[SOFT_TREE_DEPTH_LIMIT].fallback(TDoubleVec{});
    auto softTreeDepthTolerance =
        parameters[SOFT_TREE_DEPTH_TOLERANCE].fallback(TDoubleVec{});
    auto downsampleFactor = parameters[DOWNSAMPLE_FACTOR].fallback(TDoubleVec{});
    auto featureBagFraction = parameters[FEATURE_BAG_FRACTION].fallback(TDoubleVec{});
    auto predictionChangeCost = parameters[PREDICTION_CHANGE_COST].fallback(TDoubleVec{});
    auto treeTopologyChangePenalty =
        parameters[TREE_TOPOLOGY_CHANGE_PENALTY].fallback(TDoubleVec{});

    // Incremental training.
    auto forceAcceptIncrementalTraining =
        parameters[FORCE_ACCEPT_INCREMENTAL_TRAINING].fallback(false);
    auto disableHyperparameterScaling =
        parameters[DISABLE_HYPERPARAMETER_SCALING].fallback(false);
    auto dataSummarizationFraction = parameters[DATA_SUMMARIZATION_FRACTION].fallback(-1.0);
    auto previousTrainLossGap = parameters[PREVIOUS_TRAIN_LOSS_GAP].fallback(-1.0);
    auto previousTrainNumberRows =
        parameters[PREVIOUS_TRAIN_NUM_ROWS].fallback(std::size_t{0});
    auto maxNumberNewTrees = parameters[MAX_NUM_NEW_TREES].fallback(std::size_t{0});

    if (parameters[FEATURE_PROCESSORS].jsonObject() != nullptr) {
        m_CustomProcessors = *parameters[FEATURE_PROCESSORS].jsonObject();
    }

    // Range validation for the hyperparameter overrides; each violation is fatal.
    if (std::any_of(alpha.begin(), alpha.end(), [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << ALPHA << "' should be non-negative.");
    }
    if (std::any_of(lambda.begin(), lambda.end(), [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << LAMBDA << "' should be non-negative.");
    }
    if (std::any_of(gamma.begin(), gamma.end(), [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << GAMMA << "' should be non-negative.");
    }
    if (std::any_of(eta.begin(), eta.end(),
                    [](double x) { return x <= 0.0 || x > 1.0; })) {
        HANDLE_FATAL(<< "Input error: '" << ETA << "' should be in the range (0, 1].");
    }
    if (std::any_of(etaGrowthRatePerTree.begin(), etaGrowthRatePerTree.end(),
                    [](double x) { return x <= 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << ETA_GROWTH_RATE_PER_TREE << "' should be positive.");
    }
    if (std::any_of(retrainedTreeEta.begin(), retrainedTreeEta.end(),
                    [](double x) { return x <= 0.0 || x > 1.0; })) {
        HANDLE_FATAL(<< "Input error: '" << RETRAINED_TREE_ETA
                     << "' should be in the range (0, 1].");
    }
    if (std::any_of(softTreeDepthLimit.begin(), softTreeDepthLimit.end(),
                    [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << SOFT_TREE_DEPTH_LIMIT << "' should be non-negative.");
    }
    if (std::any_of(softTreeDepthTolerance.begin(), softTreeDepthTolerance.end(),
                    [](double x) { return x <= 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << SOFT_TREE_DEPTH_TOLERANCE << "' should be positive.");
    }
    if (std::any_of(downsampleFactor.begin(), downsampleFactor.end(),
                    [](double x) { return x <= 0.0 || x > 1.0; })) {
        HANDLE_FATAL(<< "Input error: '" << DOWNSAMPLE_FACTOR << "' should be in the range (0, 1]");
    }
    if (std::any_of(featureBagFraction.begin(), featureBagFraction.end(),
                    [](double x) { return x <= 0.0 || x > 1.0; })) {
        HANDLE_FATAL(<< "Input error: '" << FEATURE_BAG_FRACTION
                     << "' should be in the range (0, 1]");
    }
    if (std::any_of(predictionChangeCost.begin(), predictionChangeCost.end(),
                    [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << PREDICTION_CHANGE_COST << "' should be non-negative");
    }
    if (std::any_of(treeTopologyChangePenalty.begin(), treeTopologyChangePenalty.end(),
                    [](double x) { return x < 0.0; })) {
        HANDLE_FATAL(<< "Input error: '" << TREE_TOPOLOGY_CHANGE_PENALTY
                     << "' should be non-negative");
    }
    // The row weight column may neither be the dependent variable nor categorical.
    if (rowWeightColumnName.empty() == false &&
        (rowWeightColumnName == m_DependentVariableFieldName ||
         std::find(spec.categoricalFieldNames().begin(),
                   spec.categoricalFieldNames().end(),
                   rowWeightColumnName) != spec.categoricalFieldNames().end())) {
        HANDLE_FATAL(<< "Input error: row weight column '" << rowWeightColumnName
                     << "' can't be categorical or the same as the supplied '"
                     << DEPENDENT_VARIABLE_NAME << "'.");
    }

    m_Instrumentation.task(m_Task);

    this->computeAndSaveExecutionStrategy();

    m_BoostedTreeFactory = this->boostedTreeFactory(std::move(loss), frameAndDirectory);
    (*m_BoostedTreeFactory)
        .seed(seed)
        .numberHoldoutRows(numberHoldoutRows)
        .stopCrossValidationEarly(stopCrossValidationEarly)
        .analysisInstrumentation(m_Instrumentation)
        .trainingStateCallback(this->statePersister())
        .stopHyperparameterOptimizationEarly(stopHyperparameterOptimizationEarly)
        .forceAcceptIncrementalTraining(forceAcceptIncrementalTraining)
        .disableHyperparameterScaling(disableHyperparameterScaling)
        .downsampleFactor(std::move(downsampleFactor))
        .depthPenaltyMultiplier(std::move(alpha))
        .treeSizePenaltyMultiplier(std::move(gamma))
        .leafWeightPenaltyMultiplier(std::move(lambda))
        .eta(std::move(eta))
        .etaGrowthRatePerTree(std::move(etaGrowthRatePerTree))
        .retrainedTreeEta(std::move(retrainedTreeEta))
        .softTreeDepthLimit(std::move(softTreeDepthLimit))
        .softTreeDepthTolerance(std::move(softTreeDepthTolerance))
        .featureBagFraction(std::move(featureBagFraction))
        .predictionChangeCost(std::move(predictionChangeCost))
        .treeTopologyChangePenalty(std::move(treeTopologyChangePenalty))
        .maximumDeployedSize(maximumDeployedSize)
        .rowWeightColumnName(std::move(rowWeightColumnName));

    // Optional overrides: only pass through when the caller supplied a value.
    if (downsampleRowsPerFeature > 0) {
        m_BoostedTreeFactory->initialDownsampleRowsPerFeature(
            static_cast<double>(downsampleRowsPerFeature));
    }
    if (maxTrees > 0) {
        m_BoostedTreeFactory->maximumNumberTrees(maxTrees);
    }
    if (numberFolds > 1) {
        m_BoostedTreeFactory->numberFolds(numberFolds);
    }
    if (trainFractionPerFold > 0.0) {
        m_BoostedTreeFactory->trainFractionPerFold(trainFractionPerFold);
    }
    if (numberRoundsPerHyperparameter != NUMBER_ROUNDS_PER_HYPERPARAMETER_IS_UNSET) {
        m_BoostedTreeFactory->maximumOptimisationRoundsPerHyperparameter(numberRoundsPerHyperparameter);
    }
    if (bayesianOptimisationRestarts > 0) {
        m_BoostedTreeFactory->bayesianOptimisationRestarts(bayesianOptimisationRestarts);
    }
    if (numberTopShapValues > 0) {
        m_BoostedTreeFactory->numberTopShapValues(numberTopShapValues);
    }
    if (dataSummarizationFraction > 0) {
        m_BoostedTreeFactory->dataSummarizationFraction(dataSummarizationFraction);
    }
    if (previousTrainLossGap > 0.0) {
        m_BoostedTreeFactory->previousTrainLossGap(previousTrainLossGap);
    }
    if (previousTrainNumberRows > 0) {
        m_BoostedTreeFactory->previousTrainNumberRows(previousTrainNumberRows);
    }
    if (maxNumberNewTrees > 0) {
        m_BoostedTreeFactory->maximumNumberNewTrees(maxNumberNewTrees);
    }
}
// Out-of-line destructor: defined here where the member types are complete
// (presumably required because the header only forward declares them — verify).
CDataFrameTrainBoostedTreeRunner::~CDataFrameTrainBoostedTreeRunner() = default;

//! Number of extra data frame columns the current task needs, delegated to the
//! factory's per-task estimates.
std::size_t CDataFrameTrainBoostedTreeRunner::numberExtraColumns() const {
    switch (m_Task) {
    case api_t::E_Encode:
        return maths::analytics::CBoostedTreeFactory::estimateExtraColumnsForEncode();
    case api_t::E_Train:
        return maths::analytics::CBoostedTreeFactory::estimateExtraColumnsForTrain(
            this->spec().numberColumns(), m_DimensionPrediction, m_DimensionGradient);
    case api_t::E_Update:
        return maths::analytics::CBoostedTreeFactory::estimateExtraColumnsForTrainIncremental(
            this->spec().numberColumns(), m_DimensionPrediction, m_DimensionGradient);
    case api_t::E_Predict:
        return maths::analytics::CBoostedTreeFactory::estimateExtraColumnsForPredict(
            m_DimensionPrediction);
    }
    // NOTE(review): no default case — falling off the end is unreachable while
    // every enum value is handled above.
}

//! Choose the data frame slice capacity: start from the default for the total
//! column count, then, when multi-threaded, shrink so there is at least one
//! slice per thread and round so the thread count divides the slice count.
std::size_t CDataFrameTrainBoostedTreeRunner::dataFrameSliceCapacity() const {
    std::size_t sliceCapacity{core::dataFrameDefaultSliceCapacity(
        this->spec().numberColumns() + this->numberExtraColumns())};

    std::size_t numberThreads{this->spec().numberThreads()};
    if (numberThreads > 1) {
        std::size_t numberRows{this->spec().numberRows()};
        // Use at least one slice per thread because we parallelize work over slices.
        std::size_t capacityForOneSlicePerThread{(numberRows + numberThreads - 1) / numberThreads};
        sliceCapacity = std::min(sliceCapacity, capacityForOneSlicePerThread);
        // Round the slice size so number threads is a divisor of the number of slices.
        std::size_t numberSlices{numberRows / sliceCapacity};
        sliceCapacity = numberRows / (numberThreads * ((numberSlices + numberThreads / 2) / numberThreads));
    }

    // Never drop below 128 rows per slice.
    return std::max(sliceCapacity, std::size_t{128});
}

//! Mask of rows whose results should be written out for the current task:
//! none for encode, all for train, and only the new training rows for
//! predict/update.
core::CPackedBitVector
CDataFrameTrainBoostedTreeRunner::rowsToWriteMask(const core::CDataFrame& frame) const {
    switch (m_Task) {
    case api_t::E_Encode:
        return {frame.numberRows(), false};
    case api_t::E_Train:
        return {frame.numberRows(), true};
    case api_t::E_Predict:
    case api_t::E_Update:
        return m_BoostedTree->newTrainingRowMask();
    }
}

//! \return The name of the column holding the variable to predict.
const std::string& CDataFrameTrainBoostedTreeRunner::dependentVariableFieldName() const {
    return m_DependentVariableFieldName;
}

//! \return The name of the output field for predictions.
const std::string& CDataFrameTrainBoostedTreeRunner::predictionFieldName() const {
    return m_PredictionFieldName;
}

//! \return The trained boosted tree; fatal if it has not been built yet.
const maths::analytics::CBoostedTree& CDataFrameTrainBoostedTreeRunner::boostedTree() const {
    if (m_BoostedTree == nullptr) {
        HANDLE_FATAL(<< "Internal error: boosted tree missing. Please report this problem.");
    }
    return *m_BoostedTree;
}

//! \return The factory used to build the boosted tree; fatal if missing.
maths::analytics::CBoostedTreeFactory& CDataFrameTrainBoostedTreeRunner::boostedTreeFactory() {
    if (m_BoostedTreeFactory == nullptr) {
        HANDLE_FATAL(<< "Internal error: boosted tree factory missing. Please report this problem.");
    }
    return *m_BoostedTreeFactory;
}

//! Const overload of the factory accessor.
const maths::analytics::CBoostedTreeFactory&
CDataFrameTrainBoostedTreeRunner::boostedTreeFactory() const {
    if (m_BoostedTreeFactory == nullptr) {
        HANDLE_FATAL(<< "Internal error: boosted tree factory missing. Please report this problem.");
    }
    return *m_BoostedTreeFactory;
}

//! Check the frame has at least one regressor column and no more rows than
//! the factory supports. Violations are fatal and also reported via \p false.
bool CDataFrameTrainBoostedTreeRunner::validate(const core::CDataFrame& frame) const {
    if (frame.numberColumns() <= 1) {
        HANDLE_FATAL(<< "Input error: analysis need at least one regressor.");
        return false;
    }
    if (frame.numberRows() > maths::analytics::CBoostedTreeFactory::maximumNumberRows()) {
        HANDLE_FATAL(<< "Input error: no more than "
                     << maths::analytics::CBoostedTreeFactory::maximumNumberRows()
                     << " are supported. You need to downsample your data.");
        return false;
    }
    return true;
}

//! Feed the trained model (plus any custom feature processors supplied in the
//! configuration) into an inference model builder.
void CDataFrameTrainBoostedTreeRunner::accept(CBoostedTreeInferenceModelBuilder& builder) const {
    if (m_CustomProcessors.is_null() == false) {
        builder.addCustomProcessor(std::make_unique<COpaqueEncoding>(m_CustomProcessors));
    }
    this->boostedTree().accept(builder);
}

void CDataFrameTrainBoostedTreeRunner::computeAndSaveExecutionStrategy() {
    // We always use in core storage for the data frame for boosted tree training
    // because it is too slow to use disk.
    this->numberPartitions(1);
    this->maximumNumberRowsPerPartition(this->spec().numberRows());
}

//! Run the configured task (encode/train/update/predict) on \p frame,
//! recording estimated peak memory and elapsed time in the program counters.
void CDataFrameTrainBoostedTreeRunner::runImpl(core::CDataFrame& frame) {
    auto dependentVariablePos = std::find(frame.columnNames().begin(),
                                          frame.columnNames().end(),
                                          m_DependentVariableFieldName);
    if (dependentVariablePos == frame.columnNames().end()) {
        HANDLE_FATAL(<< "Input error: supplied variable to predict '"
                     << m_DependentVariableFieldName << "' is missing from training"
                     << " data " << frame.columnNames());
        return;
    }

    core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage) =
        this->estimateMemoryUsage(frame.numberRows(),
                                  frame.numberRows() / this->numberPartitions(),
                                  frame.numberColumns() + this->numberExtraColumns());

    core::CStopWatch watch{true};

    std::size_t dependentVariableColumn(dependentVariablePos - frame.columnNames().begin());

    // Two-argument validate overload (declared elsewhere — presumably also
    // checks the dependent variable column; confirm against the header).
    this->validate(frame, dependentVariableColumn);

    switch (m_Task) {
    case api_t::E_Encode:
        m_BoostedTree = m_BoostedTreeFactory->buildForEncode(frame, dependentVariableColumn);
        break;
    case api_t::E_Train: {
        // Prefer restoring a checkpointed model when a restore searcher exists;
        // otherwise train from scratch.
        auto restoreSearcher = this->spec().restoreSearcher();
        auto boostedTree = (restoreSearcher == nullptr)
                               ? nullptr
                               : this->restoreBoostedTree(frame, dependentVariableColumn, restoreSearcher);
        m_BoostedTree = [&] {
            return boostedTree != nullptr
                       ? std::move(boostedTree)
                       : m_BoostedTreeFactory->buildForTrain(frame, dependentVariableColumn);
        }();
        m_BoostedTree->train();
        m_BoostedTree->predict();
    } break;
    case api_t::E_Update:
        m_BoostedTree = m_BoostedTreeFactory->buildForTrainIncremental(frame, dependentVariableColumn);
        m_BoostedTree->trainIncremental();
        m_BoostedTree->predict(true /*new data only*/);
        break;
    case api_t::E_Predict:
        m_BoostedTree = m_BoostedTreeFactory->buildForPredict(frame, dependentVariableColumn);
        // Prediction occurs in buildForPredict.
        // m_BoostedTree->predict(true /*new data only*/);
        break;
    }

    core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain) = watch.stop();
}

//! Create the boosted tree factory for the current task. For update/predict a
//! previously trained model definition is restored from the state store (and a
//! data frame is created to receive its data summarization); for encode/train,
//! or when \p frameAndDirectory is null (memory estimation only), a fresh
//! factory is constructed from parameters.
CDataFrameTrainBoostedTreeRunner::TBoostedTreeFactoryUPtr
CDataFrameTrainBoostedTreeRunner::boostedTreeFactory(TLossFunctionUPtr loss,
                                                     TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) const {
    switch (m_Task) {
    case api_t::E_Encode:
    case api_t::E_Train:
        break;
    case api_t::E_Update:
    case api_t::E_Predict:
        if (frameAndDirectory != nullptr) {
            // This will be null if we're just computing memory usage.
            auto restoreSearcher = this->spec().restoreSearcher();
            if (restoreSearcher == nullptr) {
                HANDLE_FATAL(<< "Input error: can't predict or incrementally training without supplying a model.");
                break;
            }
            *frameAndDirectory = this->makeDataFrame();
            // Restorer callbacks for the compressed JSON model definition.
            auto bestForestRestorer =
                [](CRetrainableModelJsonReader::TIStreamSPtr inputStream,
                   const CRetrainableModelJsonReader::TStrSizeUMap& encodingsIndices) {
                    return CRetrainableModelJsonReader::bestForestFromCompressedJsonStream(
                        std::move(inputStream), encodingsIndices);
                };
            auto dataSummarizationRestorer =
                [](CRetrainableModelJsonReader::TIStreamSPtr inputStream,
                   core::CDataFrame& frame) {
                    return CRetrainableModelJsonReader::dataSummarizationFromCompressedJsonStream(
                        std::move(inputStream), frame);
                };
            auto& frame = frameAndDirectory->first;
            auto result = std::make_unique<maths::analytics::CBoostedTreeFactory>(
                maths::analytics::CBoostedTreeFactory::constructFromDefinition(
                    this->spec().numberThreads(), std::move(loss), *restoreSearcher,
                    *frame, dataSummarizationRestorer, bestForestRestorer));
            // Restored rows are not "new" training data.
            result->newTrainingRowMask(core::CPackedBitVector{frame->numberRows(), false});
            return result;
        }
        break;
    }
    return std::make_unique<maths::analytics::CBoostedTreeFactory>(
        maths::analytics::CBoostedTreeFactory::constructFromParameters(
            this->spec().numberThreads(), std::move(loss)));
}

//! Restore a boosted tree from compressed JSON checkpoint state, returning
//! null (after logging) on any stream or parse failure.
CDataFrameTrainBoostedTreeRunner::TBoostedTreeUPtr
CDataFrameTrainBoostedTreeRunner::restoreBoostedTree(core::CDataFrame& frame,
                                                     std::size_t dependentVariableColumn,
                                                     const TDataSearcherUPtr& restoreSearcher) {
    if (restoreSearcher == nullptr) {
        return nullptr;
    }

    // Restore from compressed JSON.
    try {
        core::CStateDecompressor decompressor{*restoreSearcher};
        core::CDataSearcher::TIStreamP inputStream{decompressor.search(1, 1)}; // search arguments are ignored
        if (inputStream == nullptr) {
            LOG_ERROR(<< "Unable to connect to data store");
            return nullptr;
        }

        if (inputStream->bad()) {
            LOG_ERROR(<< "State restoration search returned bad stream");
            return nullptr;
        }

        if (inputStream->fail()) {
            // This is fatal. If the stream exists and has failed then state is missing
            LOG_ERROR(<< "State restoration search returned failed stream");
            return nullptr;
        }
        return maths::analytics::CBoostedTreeFactory::constructFromString(*inputStream)
            .analysisInstrumentation(m_Instrumentation)
            .trainingStateCallback(this->statePersister())
            .restoreFor(frame, dependentVariableColumn);
    } catch (std::exception& e) {
        LOG_ERROR(<< "Failed to restore state! " << e.what());
    }
    return nullptr;
}

//! Estimate the memory bookkeeping overhead for the current task. Only
//! m_TrainingPercent of the rows take part in training; update/predict also
//! account for the already-trained model's size.
std::size_t CDataFrameTrainBoostedTreeRunner::estimateBookkeepingMemoryUsage(
    std::size_t /*numberPartitions*/,
    std::size_t totalNumberRows,
    std::size_t /*partitionNumberRows*/,
    std::size_t numberColumns) const {
    std::size_t numberTrainingRows{static_cast<std::size_t>(
        static_cast<double>(totalNumberRows) * m_TrainingPercent + 0.5)};
    switch (m_Task) {
    case api_t::E_Encode:
        return m_BoostedTreeFactory->estimateMemoryUsageForEncode(
            numberTrainingRows, numberColumns, this->spec().categoricalFieldNames().size());
    case api_t::E_Train:
        return m_BoostedTreeFactory->estimateMemoryUsageForTrain(numberTrainingRows, numberColumns);
    case api_t::E_Update:
        return m_TrainedModelMemoryUsage +
               m_BoostedTreeFactory->estimateMemoryUsageForTrainIncremental(
                   numberTrainingRows, numberColumns);
    case api_t::E_Predict:
        return m_TrainedModelMemoryUsage +
               m_BoostedTreeFactory->estimateMemoryUsageForPredict(numberTrainingRows, numberColumns);
    }
}

//! \return The analysis instrumentation (read-only).
const CDataFrameAnalysisInstrumentation&
CDataFrameTrainBoostedTreeRunner::instrumentation() const {
    return m_Instrumentation;
}

//! \return The analysis instrumentation (writable).
CDataFrameAnalysisInstrumentation& CDataFrameTrainBoostedTreeRunner::instrumentation() {
    return m_Instrumentation;
}

//! Build a JSON writer for the model's data summarization, or null when the
//! summarization row mask is empty.
CDataFrameAnalysisRunner::TDataSummarizationJsonWriterUPtr
CDataFrameTrainBoostedTreeRunner::dataSummarization() const {
    auto rowMask = this->boostedTree().dataSummarization();
    if (rowMask.manhattan() <= 0.0) {
        return {};
    }
    return std::make_unique<CDataSummarizationJsonWriter>(
        this->boostedTree().trainingData(), std::move(rowMask),
        this->spec().numberColumns(), this->boostedTree().categoryEncoder());
}

// Configuration parameter and output field names. These are part of the
// external JSON protocol; do not rename the string values.
// clang-format off
const std::string CDataFrameTrainBoostedTreeRunner::RANDOM_NUMBER_GENERATOR_SEED{"randomize_seed"};
const std::string CDataFrameTrainBoostedTreeRunner::DEPENDENT_VARIABLE_NAME{"dependent_variable"};
const std::string CDataFrameTrainBoostedTreeRunner::PREDICTION_FIELD_NAME{"prediction_field_name"};
const std::string CDataFrameTrainBoostedTreeRunner::TRAINING_PERCENT_FIELD_NAME{"training_percent"};
const std::string CDataFrameTrainBoostedTreeRunner::DOWNSAMPLE_ROWS_PER_FEATURE{"downsample_rows_per_feature"};
const std::string CDataFrameTrainBoostedTreeRunner::DOWNSAMPLE_FACTOR{"downsample_factor"};
const std::string CDataFrameTrainBoostedTreeRunner::ALPHA{"alpha"};
const std::string CDataFrameTrainBoostedTreeRunner::LAMBDA{"lambda"};
const std::string CDataFrameTrainBoostedTreeRunner::GAMMA{"gamma"};
const std::string CDataFrameTrainBoostedTreeRunner::ETA{"eta"};
const std::string CDataFrameTrainBoostedTreeRunner::ETA_GROWTH_RATE_PER_TREE{"eta_growth_rate_per_tree"};
const std::string CDataFrameTrainBoostedTreeRunner::RETRAINED_TREE_ETA{"retrained_tree_eta"};
const std::string CDataFrameTrainBoostedTreeRunner::SOFT_TREE_DEPTH_LIMIT{"soft_tree_depth_limit"};
const std::string CDataFrameTrainBoostedTreeRunner::SOFT_TREE_DEPTH_TOLERANCE{"soft_tree_depth_tolerance"};
const std::string CDataFrameTrainBoostedTreeRunner::MAX_TREES{"max_trees"};
const std::string CDataFrameTrainBoostedTreeRunner::MAX_DEPLOYED_MODEL_SIZE{"max_model_size"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_BAG_FRACTION{"feature_bag_fraction"};
const std::string CDataFrameTrainBoostedTreeRunner::PREDICTION_CHANGE_COST{"prediction_change_cost"};
const std::string CDataFrameTrainBoostedTreeRunner::TREE_TOPOLOGY_CHANGE_PENALTY{"tree_topology_change_penalty"};
const std::string CDataFrameTrainBoostedTreeRunner::TRAINED_MODEL_MEMORY_USAGE{"trained_model_memory_usage"};
const std::string CDataFrameTrainBoostedTreeRunner::NUM_HOLDOUT_ROWS{"num_holdout_rows"};
const std::string CDataFrameTrainBoostedTreeRunner::NUM_FOLDS{"num_folds"};
const std::string CDataFrameTrainBoostedTreeRunner::TRAIN_FRACTION_PER_FOLD{"train_fraction_per_fold"};
const std::string CDataFrameTrainBoostedTreeRunner::STOP_CROSS_VALIDATION_EARLY{"stop_cross_validation_early"};
const std::string CDataFrameTrainBoostedTreeRunner::MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER{"max_optimization_rounds_per_hyperparameter"};
const std::string CDataFrameTrainBoostedTreeRunner::BAYESIAN_OPTIMISATION_RESTARTS{"bayesian_optimisation_restarts"};
const std::string CDataFrameTrainBoostedTreeRunner::NUM_TOP_FEATURE_IMPORTANCE_VALUES{"num_top_feature_importance_values"};
const std::string CDataFrameTrainBoostedTreeRunner::IS_TRAINING_FIELD_NAME{"is_training"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME{"feature_name"};
const std::string CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME{"importance"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME{"feature_importance"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_PROCESSORS{"feature_processors"};
const std::string CDataFrameTrainBoostedTreeRunner::EARLY_STOPPING_ENABLED{"early_stopping_enabled"};
const std::string CDataFrameTrainBoostedTreeRunner::FORCE_ACCEPT_INCREMENTAL_TRAINING{"force_accept_incremental_training"};
const std::string CDataFrameTrainBoostedTreeRunner::DISABLE_HYPERPARAMETER_SCALING{"disable_hyperparameter_scaling"};
const std::string CDataFrameTrainBoostedTreeRunner::DATA_SUMMARIZATION_FRACTION{"data_summarization_fraction"};
const std::string CDataFrameTrainBoostedTreeRunner::TASK{"task"};
const std::string CDataFrameTrainBoostedTreeRunner::TASK_ENCODE{"encode"};
const std::string CDataFrameTrainBoostedTreeRunner::TASK_TRAIN{"train"};
const std::string CDataFrameTrainBoostedTreeRunner::TASK_UPDATE{"update"};
const std::string CDataFrameTrainBoostedTreeRunner::TASK_PREDICT{"predict"};
const std::string CDataFrameTrainBoostedTreeRunner::PREVIOUS_TRAIN_LOSS_GAP{"previous_train_loss_gap"};
const std::string CDataFrameTrainBoostedTreeRunner::PREVIOUS_TRAIN_NUM_ROWS{"previous_train_num_rows"};
const std::string CDataFrameTrainBoostedTreeRunner::MAX_NUM_NEW_TREES{"max_num_new_trees"};
const std::string CDataFrameTrainBoostedTreeRunner::ROW_WEIGHT_COLUMN{"weight_column"};
// clang-format on
}
}