in lib/maths/analytics/CBoostedTreeImpl.cc [293:444]
void CBoostedTreeImpl::train(core::CDataFrame& frame,
                             const TTrainingStateCallback& recordTrainStateCallback) {

    this->checkTrainInvariants(frame);

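    // Label this run as regression or classification in the instrumentation output.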
    m_Instrumentation->type(
        m_Loss->isRegression()
            ? CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression
            : CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);

    LOG_TRACE(<< "Main training loop...");

    m_TrainingProgress.progressCallback(m_Instrumentation->progressCallback());

    std::int64_t lastMemoryUsage(this->memoryUsage());

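    // Masks selecting every training row and no rows at all, respectively.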
    core::CPackedBitVector allTrainingRowMask{this->allTrainingRowMask()};
    core::CPackedBitVector noRowsMask{allTrainingRowMask.size(), false};

    this->startProgressMonitoringFineTuneHyperparameters();

    if (this->canTrain() == false) {

        // Fallback to using the constant predictor which minimises the loss.

        this->startProgressMonitoringFinalTrain();
        m_BestForest.assign(1, this->initializePredictionsAndLossDerivatives(
                                   frame, allTrainingRowMask, noRowsMask));
        TMeanVarAccumulator testLossMoments;
        testLossMoments += this->meanLoss(frame, allTrainingRowMask);
        m_Hyperparameters.captureBest(
            testLossMoments, 0.0 /*no loss gap*/, 0.0 /*no kept nodes*/,
            1.0 /*single node used to centre the data*/, 1 /*single tree*/);
        LOG_TRACE(<< "Test loss = " << m_Hyperparameters.bestForestTestLoss());

    } else if (m_Hyperparameters.fineTuneSearchNotFinished() || m_BestForest.empty()) {

        TMeanVarAccumulator timeAccumulator;
        core::CStopWatch stopWatch;
        stopWatch.start();
        std::uint64_t lastLap{stopWatch.lap()};

        // Hyperparameter optimisation loop.
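        // minTestLosses holds the minimum test loss observed so far for each
        // fold; it is threaded through crossValidateForest into each per-fold
        // trainForest call below.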
        TDoubleVec minTestLosses{this->initializePerFoldTestLosses()};

        for (m_Hyperparameters.startFineTuneSearch();
             m_Hyperparameters.fineTuneSearchNotFinished();
             /**/) {

            LOG_TRACE(<< "Optimisation round = " << m_Hyperparameters.currentRound() + 1);
            m_Instrumentation->iteration(m_Hyperparameters.currentRound() + 1);

            this->recordHyperparameters();

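            // Train and evaluate one forest per cross-validation fold with the
            // current hyperparameter candidate.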
            auto crossValidationResult = this->crossValidateForest(
                frame, m_Hyperparameters.maximumNumberTrees().value(),
                [this](core::CDataFrame& frame_, const core::CPackedBitVector& trainingRowMask,
                       const core::CPackedBitVector& testingRowMask,
                       double minTestLoss, core::CLoopProgress& trainingProgress) {
                    return this->trainForest(frame_, trainingRowMask, testingRowMask,
                                             trainingProgress, minTestLoss);
                },
                minTestLosses);

            // If we're evaluating using a hold-out set we will not retrain on the
            // full data set at the end.
            if (m_Hyperparameters.captureBest(
                    crossValidationResult.s_TestLossMoments,
                    crossValidationResult.s_MeanLossGap, 0.0 /*no kept nodes*/,
                    crossValidationResult.s_NumberNodes, crossValidationResult.s_NumberTrees) &&
                m_UserSuppliedHoldOutSet) {
                m_BestForest = std::move(crossValidationResult.s_Forest);
            }

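            // Choose the hyperparameters to evaluate next; selectNext returns
            // false if the fine-tune search should stop early.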
            if (m_Hyperparameters.selectNext(crossValidationResult.s_TestLossMoments,
                                             this->betweenFoldTestLossVariance()) == false) {
                LOG_INFO(<< "Stopping fine-tune hyperparameters on round "
                         << m_Hyperparameters.currentRound() << " out of "
                         << m_Hyperparameters.numberRounds());
                break;
            }

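            // Report the change in memory usage since the last round.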
            std::int64_t memoryUsage(this->memoryUsage());
            m_Instrumentation->updateMemoryUsage(memoryUsage - lastMemoryUsage);
            lastMemoryUsage = memoryUsage;

            // We need to update the current round before we persist so we don't
            // perform an extra round when we fail over.
            m_Hyperparameters.startNextRound();

            // Store the training state after each hyperparameter search step.
            LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound()
                      << " state recording started");
            this->recordState(recordTrainStateCallback);
            LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound()
                      << " state recording finished");

            std::uint64_t currentLap{stopWatch.lap()};
            std::uint64_t delta{currentLap - lastLap};
            m_Instrumentation->iterationTime(delta);
            timeAccumulator.add(static_cast<double>(delta));
            lastLap = currentLap;

            m_Instrumentation->flush(HYPERPARAMETER_OPTIMIZATION_ROUND +
                                     std::to_string(m_Hyperparameters.currentRound()));
        }

LOG_TRACE(<< "Test loss = " << m_Hyperparameters.bestForestTestLoss());
        if (m_BestForest.empty()) {
            m_Hyperparameters.restoreBest();
            m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
            m_Hyperparameters.captureScale();

            this->startProgressMonitoringFinalTrain();

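            // Rescale the regularisation multipliers, which were tuned training
            // on individual folds, by the ratio of the full training set size
            // (allTrainingRowMask.manhattan()) to the mean fold size.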
            this->scaleRegularizationMultipliers(allTrainingRowMask.manhattan() /
                                                 this->meanNumberTrainingRowsPerFold());

            // Reinitialize random number generator for reproducible results.
            m_Rng.seed(m_Seed);

            m_BestForest = this->trainForest(frame, allTrainingRowMask,
                                             allTrainingRowMask, m_TrainingProgress)
                               .s_Forest;

            this->recordState(recordTrainStateCallback);

        } else {
            this->skipProgressMonitoringFinalTrain();
        }

        m_Instrumentation->iteration(m_Hyperparameters.currentRound());
        m_Instrumentation->flush(TRAIN_FINAL_FOREST);

        timeAccumulator.add(static_cast<double>(stopWatch.stop() - lastLap));
        LOG_TRACE(<< "Training finished after " << m_Hyperparameters.currentRound()
                  << " iterations. Time per iteration in ms mean: "
                  << common::CBasicStatistics::mean(timeAccumulator) << " std. dev: "
                  << std::sqrt(common::CBasicStatistics::variance(timeAccumulator)));

        core::CProgramCounters::counter(counter_t::E_DFTPMTrainedForestNumberTrees) =
            m_BestForest.size();

    } else {
        this->skipProgressMonitoringFinalTrain();
    }

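    // Compute the weights applied to class predictions and prepare TreeSHAP so
    // feature importance can be computed for the trained forest.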
    this->computeClassificationWeights(frame);
    this->initializeTreeShap(frame);

    // Force progress to one and record the final memory usage.
    m_Instrumentation->updateProgress(1.0);
    m_Instrumentation->updateMemoryUsage(
        static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}