in lib/maths/analytics/CBoostedTreeImpl.cc [446:622]
void CBoostedTreeImpl::trainIncremental(core::CDataFrame& frame,
                                        const TTrainingStateCallback& recordTrainStateCallback) {

    this->checkIncrementalTrainInvariants(frame);
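    // Incremental training is a no-op if the forest comprises a single tree or
    // there is no new training data.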
    if (m_BestForest.size() == 1 || m_NewTrainingRowMask.manhattan() == 0.0) {
        return;
    }

    LOG_TRACE(<< "Main incremental training loop...");

    this->selectTreesToRetrain(frame);
    // Add dummy trees that can be replaced with the new trees in the forest.
    std::size_t oldBestForestSize{m_BestForest.size()};
    m_BestForest.resize(oldBestForestSize + m_MaximumNumberNewTrees);
    for (auto i = oldBestForestSize; i < m_BestForest.size(); ++i) {
        m_BestForest[i] = {CBoostedTreeNode(m_Loss->dimensionPrediction())};
    }
    m_TreesToRetrain.resize(m_TreesToRetrain.size() + m_MaximumNumberNewTrees);
    std::iota(m_TreesToRetrain.end() - m_MaximumNumberNewTrees,
              m_TreesToRetrain.end(), oldBestForestSize);
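    // The indices appended here address the dummy slots, so retraining can add
    // up to m_MaximumNumberNewTrees genuinely new trees as well as replace
    // existing ones.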
    TNodeVecVec retrainedTrees;
    std::int64_t lastMemoryUsage(this->memoryUsage());

    this->startProgressMonitoringTrainIncremental();
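    // Tally the nodes in the trees we retrain and the trees we keep: these
    // counts feed the model size penalty used in the loss comparisons below.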
    double retrainedNumberNodes{0.0};
    for (const auto& i : m_TreesToRetrain) {
        retrainedNumberNodes += static_cast<double>(m_BestForest[i].size());
    }
    double numberKeptNodes{numberForestNodes(m_BestForest) - retrainedNumberNodes};

    // Make sure that our predictions are correctly initialised before computing
    // the initial loss.
    auto allTrainingRowMask = this->allTrainingRowMask();
    auto noRowsMask = core::CPackedBitVector{allTrainingRowMask.size(), false};
    this->initializePredictionsAndLossDerivatives(frame, allTrainingRowMask, noRowsMask);
    // When we decide whether to accept the results of incremental training below
    // we compare the loss calculated for the best candidate forest with the loss
    // calculated with the original model. Since the data summary comprises a subset
    // of the training data we are in effect comparing training error on old data
    // plus validation error on new training data with something closer to validation
    // error on all data. If we don't have much new data, or the improvement we can
    // make on it is small, this typically causes us to reject models which actually
    // perform better in test. We therefore record the gap between the train and
    // validation loss on the old training data during train and add it to the
    // acceptance threshold, adjusted for the proportion of old training data we have.
    double numberNewTrainingRows{m_NewTrainingRowMask.manhattan()};
    double numberOldTrainingRows{allTrainingRowMask.manhattan() - numberNewTrainingRows};
    double initialLoss{common::CBasicStatistics::mean([&] {
                           TMeanVarAccumulator lossMoments;
                           for (const auto& mask : m_TestingRowMasks) {
                               lossMoments += this->meanChangePenalisedLoss(frame, mask);
                           }
                           return lossMoments;
                       }()) +
                       this->expectedLossGapAfterTrainIncremental(
                           numberOldTrainingRows, numberNewTrainingRows)};
    // Hyperparameter optimisation loop.
    TDoubleVec minTestLosses{this->initializePerFoldTestLosses()};
    std::size_t numberTreesToRetrain{this->numberTreesToRetrain()};

    TMeanVarAccumulator timeAccumulator;
    core::CStopWatch stopWatch;
    stopWatch.start();
    std::uint64_t lastLap{stopWatch.lap()};

    LOG_TRACE(<< "Number trees to retrain = " << numberTreesToRetrain << "/"
              << m_BestForest.size());
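    // Each round cross-validates retraining with one candidate set of
    // hyperparameters; the fine-tune search then proposes the next candidate
    // or signals that the search should stop.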
    for (m_Hyperparameters.startFineTuneSearch();
         m_Hyperparameters.fineTuneSearchNotFinished();
         /**/) {

        LOG_TRACE(<< "Optimisation round = " << m_Hyperparameters.currentRound() + 1);
        m_Instrumentation->iteration(m_Hyperparameters.currentRound() + 1);

        this->recordHyperparameters();

        auto crossValidationResult = this->crossValidateForest(
            frame, numberTreesToRetrain,
            [this](core::CDataFrame& frame_, const core::CPackedBitVector& trainingRowMask,
                   const core::CPackedBitVector& testingRowMask,
                   double /*minTestLoss*/, core::CLoopProgress& trainingProgress) {
                return this->updateForest(frame_, trainingRowMask,
                                          testingRowMask, trainingProgress);
            },
            minTestLosses);
        // If we're evaluating with a user-supplied hold-out set we won't retrain
        // on the full data set at the end, so capture the forest of the best
        // round from cross-validation instead.
        if (m_Hyperparameters.captureBest(crossValidationResult.s_TestLossMoments,
                                          crossValidationResult.s_MeanLossGap, numberKeptNodes,
                                          crossValidationResult.s_NumberNodes,
                                          crossValidationResult.s_NumberTrees) &&
            m_UserSuppliedHoldOutSet) {
            retrainedTrees = std::move(crossValidationResult.s_Forest);
        }
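        // selectNext proposes the hyperparameters for the next round and
        // returns false when the search should stop early.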
        if (m_Hyperparameters.selectNext(crossValidationResult.s_TestLossMoments,
                                         this->betweenFoldTestLossVariance()) == false) {
            LOG_INFO(<< "Stopping fine-tune hyperparameters on round "
                     << m_Hyperparameters.currentRound() << " out of "
                     << m_Hyperparameters.numberRounds());
            break;
        }
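        // Record the change in memory usage since the last round.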
        std::int64_t memoryUsage(this->memoryUsage());
        m_Instrumentation->updateMemoryUsage(memoryUsage - lastMemoryUsage);
        lastMemoryUsage = memoryUsage;

        // We need to update the current round before we persist so we don't
        // perform an extra round when we fail over.
        m_Hyperparameters.startNextRound();

        LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound() << " state recording started");
        this->recordState(recordTrainStateCallback);
        LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound() << " state recording finished");

        std::uint64_t currentLap{stopWatch.lap()};
        std::uint64_t delta{currentLap - lastLap};
        m_Instrumentation->iterationTime(delta);

        timeAccumulator.add(static_cast<double>(delta));
        lastLap = currentLap;
        m_Instrumentation->flush(HYPERPARAMETER_OPTIMIZATION_ROUND +
                                 std::to_string(m_Hyperparameters.currentRound()));
    }
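    // Add the size penalty for the unchanged model to the initial loss so the
    // acceptance check below compares penalised losses on an equal footing.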
    initialLoss += m_Hyperparameters.modelSizePenalty(numberKeptNodes, retrainedNumberNodes);

    LOG_TRACE(<< "Incremental training finished after "
              << m_Hyperparameters.currentRound() << " iterations. "
              << "Time per iteration in ms mean: "
              << common::CBasicStatistics::mean(timeAccumulator) << " std. dev: "
              << std::sqrt(common::CBasicStatistics::variance(timeAccumulator)));
    LOG_TRACE(<< "best forest loss = " << m_Hyperparameters.bestForestTestLoss()
              << ", initial loss = " << initialLoss);
    if (m_ForceAcceptIncrementalTraining || m_Hyperparameters.bestForestTestLoss() < initialLoss) {
        m_Hyperparameters.restoreBest();
        m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
        m_Hyperparameters.captureScale();
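        // Unless we already captured the forest during cross-validation (the
        // user-supplied hold-out case), retrain the selected trees on all the
        // training data with the best hyperparameters.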
        if (retrainedTrees.empty()) {
            this->scaleRegularizationMultipliers(allTrainingRowMask.manhattan() /
                                                 this->meanNumberTrainingRowsPerFold());
            // Reinitialize random number generator for reproducible results.
            m_Rng.seed(m_Seed);
            retrainedTrees = this->updateForest(frame, allTrainingRowMask,
                                                allTrainingRowMask, m_TrainingProgress)
                                 .s_Forest;
        }
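        // Splice the retrained trees back into their slots in the forest.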
        for (std::size_t i = 0; i < retrainedTrees.size(); ++i) {
            m_BestForest[m_TreesToRetrain[i]] = std::move(retrainedTrees[i]);
        }

        // Resize the forest to eliminate the unused dummy trees.
        auto lastChangedTreeIndex = m_TreesToRetrain[retrainedTrees.size() - 1];
        auto bestForestSize = std::max(lastChangedTreeIndex + 1,
                                       m_BestForest.size() - m_MaximumNumberNewTrees);
        m_BestForest.resize(bestForestSize);
    }
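    // Refresh quantities which depend on the final forest.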
    this->computeClassificationWeights(frame);
    this->initializeTreeShap(frame);

    // Force progress to one and record the final memory usage.
    m_Instrumentation->updateProgress(1.0);
    m_Instrumentation->updateMemoryUsage(
        static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}