include/api/CDataFrameTrainBoostedTreeRunner.h (130 lines of code) (raw):
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the following additional limitation. Functionality enabled by the
* files subject to the Elastic License 2.0 may only be used in production when
* invoked by an Elasticsearch process with a license key installed that permits
* use of machine learning features. You may not use this file except in
* compliance with the Elastic License 2.0 and the foregoing additional
* limitation.
*/
#ifndef INCLUDED_ml_api_CDataFrameTrainBoostedTreeRunner_h
#define INCLUDED_ml_api_CDataFrameTrainBoostedTreeRunner_h
#include <maths/common/CBasicStatistics.h>
#include <api/ApiTypes.h>
#include <api/CDataFrameAnalysisInstrumentation.h>
#include <api/CDataFrameAnalysisRunner.h>
#include <api/CDataFrameAnalysisSpecification.h>
#include <api/ImportExport.h>
#include <boost/json.hpp>
#include <memory>
namespace ml {
namespace maths {
namespace analytics {
// Forward declarations only: this header refers to these types solely through
// references and std::unique_ptr, so their full definitions are not needed here.
class CBoostedTree;
class CBoostedTreeFactory;
namespace boosted_tree {
// The loss function type held by TLossFunctionUPtr.
class CLoss;
}
}
}
namespace api {
// Forward declarations: each of these types is used only by reference in the
// declarations below, so incomplete types are sufficient in this header.
class CBoostedTreeInferenceModelBuilder;
class CDataFrameAnalysisConfigReader;
class CDataFrameAnalysisParameters;
//! \brief Abstract base for runners of boosted tree analyses on a core::CDataFrame.
//!
//! DESCRIPTION:\n
//! Declares the string constants, accessors and state which are shared by the
//! boosted tree runners. The class cannot be instantiated directly: the
//! constructor is protected and takes the loss function to use, and subclasses
//! must implement validate(const core::CDataFrame&, std::size_t).
class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRunner {
public:
    // NOTE(review): these look like the names of the input parameters accepted
    // by the analysis (see parameterReader()) - confirm against the
    // implementation file. They are only declared here, so each one must be
    // defined out of line.
    static const std::string RANDOM_NUMBER_GENERATOR_SEED;
    static const std::string DEPENDENT_VARIABLE_NAME;
    static const std::string PREDICTION_FIELD_NAME;
    static const std::string DOWNSAMPLE_ROWS_PER_FEATURE;
    static const std::string DOWNSAMPLE_FACTOR;
    static const std::string ALPHA;
    static const std::string LAMBDA;
    static const std::string GAMMA;
    static const std::string ETA;
    static const std::string ETA_GROWTH_RATE_PER_TREE;
    static const std::string RETRAINED_TREE_ETA;
    static const std::string SOFT_TREE_DEPTH_LIMIT;
    static const std::string SOFT_TREE_DEPTH_TOLERANCE;
    static const std::string MAX_TREES;
    static const std::string MAX_DEPLOYED_MODEL_SIZE;
    static const std::string FEATURE_BAG_FRACTION;
    static const std::string PREDICTION_CHANGE_COST;
    static const std::string TREE_TOPOLOGY_CHANGE_PENALTY;
    static const std::string TRAINED_MODEL_MEMORY_USAGE;
    static const std::string NUM_FOLDS;
    static const std::string TRAIN_FRACTION_PER_FOLD;
    static const std::string NUM_HOLDOUT_ROWS;
    static const std::string STOP_CROSS_VALIDATION_EARLY;
    static const std::string MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER;
    static const std::string BAYESIAN_OPTIMISATION_RESTARTS;
    static const std::string NUM_TOP_FEATURE_IMPORTANCE_VALUES;
    static const std::string TRAINING_PERCENT_FIELD_NAME;
    static const std::string FEATURE_PROCESSORS;
    static const std::string EARLY_STOPPING_ENABLED;
    static const std::string FORCE_ACCEPT_INCREMENTAL_TRAINING;
    static const std::string DISABLE_HYPERPARAMETER_SCALING;
    static const std::string DATA_SUMMARIZATION_FRACTION;
    // NOTE(review): TASK_* are presumably the values accepted for TASK, which
    // map onto api_t::EDataFrameTrainBoostedTreeTask - confirm.
    static const std::string TASK;
    static const std::string TASK_ENCODE;
    static const std::string TASK_TRAIN;
    static const std::string TASK_UPDATE;
    static const std::string TASK_PREDICT;
    static const std::string PREVIOUS_TRAIN_LOSS_GAP;
    static const std::string PREVIOUS_TRAIN_NUM_ROWS;
    static const std::string MAX_NUM_NEW_TREES;
    static const std::string ROW_WEIGHT_COLUMN;

    // Output
    // NOTE(review): presumably the names of fields written to the results -
    // confirm against the implementation file.
    static const std::string IS_TRAINING_FIELD_NAME;
    static const std::string FEATURE_NAME_FIELD_NAME;
    static const std::string IMPORTANCE_FIELD_NAME;
    static const std::string FEATURE_IMPORTANCE_FIELD_NAME;

public:
    ~CDataFrameTrainBoostedTreeRunner() override;

    //! \return The number of columns this adds to the data frame.
    std::size_t numberExtraColumns() const override;

    //! \return The capacity of the data frame slice to use.
    std::size_t dataFrameSliceCapacity() const override;

    //! \return A mask of the rows of \p frame to write. This is either all rows
    //! or new training data if updating.
    core::CPackedBitVector rowsToWriteMask(const core::CDataFrame& frame) const override;

    //! \return The boosted tree.
    //!
    //! \note The tree is held by std::unique_ptr; whether it is valid to call
    //! this before the analysis has run is not visible from this header.
    const maths::analytics::CBoostedTree& boostedTree() const;

    //! \return Read-only reference to the analysis state.
    const CDataFrameAnalysisInstrumentation& instrumentation() const override;

    //! \return Mutable reference to the analysis state.
    CDataFrameAnalysisInstrumentation& instrumentation() override;

    //! \return A serialisable summarization of the training data.
    TDataSummarizationJsonWriterUPtr dataSummarization() const override;

protected:
    using TLossFunctionUPtr = std::unique_ptr<maths::analytics::boosted_tree::CLoss>;

protected:
    //! Only callable by subclasses, which supply the loss function to use.
    //!
    //! \param[in] spec The analysis specification.
    //! \param[in] parameters The analysis parameters.
    //! \param[in] loss The loss function; ownership is transferred to this call.
    //! \param[in] frameAndDirectory Passed by pointer. NOTE(review): whether
    //! this may be null and how it is used is defined by the implementation.
    CDataFrameTrainBoostedTreeRunner(const CDataFrameAnalysisSpecification& spec,
                                     const CDataFrameAnalysisParameters& parameters,
                                     TLossFunctionUPtr loss,
                                     TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory);

    //! \return The parameter reader handling parameters that are shared by subclasses.
    static const CDataFrameAnalysisConfigReader& parameterReader();

    //! \return The task to perform. This is api_t::E_Train unless the stored
    //! value has been changed from its in-class default.
    api_t::EDataFrameTrainBoostedTreeTask task() const { return m_Task; }

    //! \return The name of dependent variable field.
    const std::string& dependentVariableFieldName() const;

    //! \return The name of prediction field.
    const std::string& predictionFieldName() const;

    //! \return Read-only reference to the boosted tree factory.
    const maths::analytics::CBoostedTreeFactory& boostedTreeFactory() const;

    //! \return Mutable reference to the boosted tree factory.
    maths::analytics::CBoostedTreeFactory& boostedTreeFactory();

    //! Validate if \p frame is suitable for running the analysis on.
    //!
    //! \return A bool. NOTE(review): presumably true when \p frame is
    //! suitable - confirm against the base class contract.
    bool validate(const core::CDataFrame& frame) const override;

    //! Write the boosted tree and custom processors to \p builder.
    void accept(CBoostedTreeInferenceModelBuilder& builder) const;

private:
    using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::analytics::CBoostedTreeFactory>;
    using TBoostedTreeUPtr = std::unique_ptr<maths::analytics::CBoostedTree>;
    using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;

private:
    //! Override of a base class virtual; see the base class for the contract.
    void computeAndSaveExecutionStrategy() override;

    //! Override of a base class virtual which is passed the data frame to
    //! operate on; see the base class for the contract.
    void runImpl(core::CDataFrame& frame) override;

    //! \return A newly created, uniquely owned boosted tree factory.
    //! Ownership of \p loss is transferred to this call.
    TBoostedTreeFactoryUPtr
    boostedTreeFactory(TLossFunctionUPtr loss,
                       TDataFrameUPtrTemporaryDirectoryPtrPr* frameAndDirectory) const;

    //! \return A uniquely owned boosted tree. NOTE(review): presumably
    //! restored from previously persisted state read via \p restoreSearcher;
    //! whether null can be returned on failure is not visible here.
    TBoostedTreeUPtr restoreBoostedTree(core::CDataFrame& frame,
                                        std::size_t dependentVariableColumn,
                                        const TDataSearcherUPtr& restoreSearcher);

    //! Override of a base class virtual. NOTE(review): presumably returns an
    //! estimate of the memory needed in addition to the data frame itself;
    //! the units are defined by the base class.
    std::size_t estimateBookkeepingMemoryUsage(std::size_t numberPartitions,
                                               std::size_t totalNumberRows,
                                               std::size_t partitionNumberRows,
                                               std::size_t numberColumns) const override;

    //! Subclass specific validation of \p frame given the column of the
    //! dependent variable. This overloads, rather than overrides, the
    //! single argument validate and returns nothing, so how a failure is
    //! reported is up to the implementation.
    virtual void validate(const core::CDataFrame& frame,
                          std::size_t dependentVariableColumn) const = 0;

private:
    // Note custom config is written directly to the factory object.

    api_t::EDataFrameTrainBoostedTreeTask m_Task{api_t::E_Train};
    json::value m_CustomProcessors;
    std::string m_DependentVariableFieldName;
    std::string m_PredictionFieldName;
    // Value-initialised, like the other scalar members, so that it can never
    // be read while indeterminate. The real value is expected to be assigned
    // during construction - confirm in the implementation file.
    double m_TrainingPercent{0.0};
    std::size_t m_DimensionPrediction{0};
    std::size_t m_DimensionGradient{0};
    std::size_t m_TrainedModelMemoryUsage{0};
    TBoostedTreeFactoryUPtr m_BoostedTreeFactory;
    TBoostedTreeUPtr m_BoostedTree;
    CDataFrameTrainBoostedTreeInstrumentation m_Instrumentation;
};
}
}
#endif // INCLUDED_ml_api_CDataFrameTrainBoostedTreeRunner_h