void CBoostedTreeImpl::trainIncremental()

in lib/maths/analytics/CBoostedTreeImpl.cc [446:622]


// Incrementally update the best forest using the rows selected by
// m_NewTrainingRowMask.
//
// Outline:
//   1. Choose which existing trees to retrain (selectTreesToRetrain).
//   2. Append up to m_MaximumNumberNewTrees placeholder ("dummy") trees that
//      new trees can replace.
//   3. Run a fine-tune hyperparameter search. Each round cross-validates
//      updateForest.
//   4. Adopt the retrained trees only if the best candidate's test loss beats
//      the original model's loss (plus an expected train/validation gap
//      adjustment and a model size penalty), or if
//      m_ForceAcceptIncrementalTraining is set.
//
// In every case the placeholder trees are removed from m_BestForest before
// returning, unless they were replaced by retrained trees.
//
// \param[in,out] frame The data frame holding the training rows. Predictions
// and loss derivatives are (re)initialised in it, presumably by writing into
// its columns; confirm against initializePredictionsAndLossDerivatives.
// \param[in] recordTrainStateCallback Used by recordState to persist training
// state at the end of each hyperparameter optimisation round.
void CBoostedTreeImpl::trainIncremental(core::CDataFrame& frame,
                                        const TTrainingStateCallback& recordTrainStateCallback) {

    this->checkIncrementalTrainInvariants(frame);

    // Nothing to do for a single tree forest or if there are no new rows.
    if (m_BestForest.size() == 1 || m_NewTrainingRowMask.manhattan() == 0.0) {
        return;
    }

    LOG_TRACE(<< "Main incremental training loop...");

    this->selectTreesToRetrain(frame);
    // Add dummy trees that can be replaced with the new trees in the forest.
    // Their indices are appended to m_TreesToRetrain so the retrained trees
    // returned by updateForest map onto them positionally.
    std::size_t oldBestForestSize{m_BestForest.size()};
    m_BestForest.resize(oldBestForestSize + m_MaximumNumberNewTrees);
    for (auto i = oldBestForestSize; i < m_BestForest.size(); ++i) {
        m_BestForest[i] = {CBoostedTreeNode(m_Loss->dimensionPrediction())};
    }
    m_TreesToRetrain.resize(m_TreesToRetrain.size() + m_MaximumNumberNewTrees);
    std::iota(m_TreesToRetrain.end() - m_MaximumNumberNewTrees,
              m_TreesToRetrain.end(), oldBestForestSize);
    TNodeVecVec retrainedTrees;

    std::int64_t lastMemoryUsage{static_cast<std::int64_t>(this->memoryUsage())};

    this->startProgressMonitoringTrainIncremental();

    // Split the forest's node count into nodes we keep and nodes we retrain.
    // Both counts include the single-node dummy trees, which cancel out.
    double retrainedNumberNodes{0.0};
    for (const auto& i : m_TreesToRetrain) {
        retrainedNumberNodes += static_cast<double>(m_BestForest[i].size());
    }
    double numberKeptNodes{numberForestNodes(m_BestForest) - retrainedNumberNodes};

    // Make sure that our predictions are correctly initialised before computing
    // the initial loss.
    auto allTrainingRowMask = this->allTrainingRowMask();
    auto noRowsMask = core::CPackedBitVector{allTrainingRowMask.size(), false};
    this->initializePredictionsAndLossDerivatives(frame, allTrainingRowMask, noRowsMask);

    // When we decide whether to accept the results of incremental training below
    // we compare the loss calculated for the best candidate forest with the loss
    // calculated with the original model. Since the data summary comprises a subset
    // of the training data, we are in effect comparing training error on old data +
    // validation error on new training data with something closer to validation
    // error on all data. If we don't have much new data, or the improvement we can
    // make on it is small, this typically causes us to reject models which actually
    // perform better in test. We record the gap between the train and validation
    // loss on the old training data in train and add it on to the threshold to
    // accept, adjusting for the proportion of old training data we have.
    double numberNewTrainingRows{m_NewTrainingRowMask.manhattan()};
    double numberOldTrainingRows{allTrainingRowMask.manhattan() - numberNewTrainingRows};
    double initialLoss{common::CBasicStatistics::mean([&] {
                           TMeanVarAccumulator lossMoments;
                           for (const auto& mask : m_TestingRowMasks) {
                               lossMoments += this->meanChangePenalisedLoss(frame, mask);
                           }
                           return lossMoments;
                       }()) +
                       this->expectedLossGapAfterTrainIncremental(
                           numberOldTrainingRows, numberNewTrainingRows)};

    // Hyperparameter optimisation loop.

    TDoubleVec minTestLosses{this->initializePerFoldTestLosses()};

    std::size_t numberTreesToRetrain{this->numberTreesToRetrain()};
    TMeanVarAccumulator timeAccumulator;
    core::CStopWatch stopWatch;
    stopWatch.start();
    std::uint64_t lastLap{stopWatch.lap()};
    LOG_TRACE(<< "Number trees to retrain = " << numberTreesToRetrain << "/"
              << m_BestForest.size());

    for (m_Hyperparameters.startFineTuneSearch();
         m_Hyperparameters.fineTuneSearchNotFinished();
         /**/) {

        LOG_TRACE(<< "Optimisation round = " << m_Hyperparameters.currentRound() + 1);
        m_Instrumentation->iteration(m_Hyperparameters.currentRound() + 1);

        this->recordHyperparameters();

        auto crossValidationResult = this->crossValidateForest(
            frame, numberTreesToRetrain,
            [this](core::CDataFrame& frame_, const core::CPackedBitVector& trainingRowMask,
                   const core::CPackedBitVector& testingRowMask,
                   double /*minTestLoss*/, core::CLoopProgress& trainingProgress) {
                return this->updateForest(frame_, trainingRowMask,
                                          testingRowMask, trainingProgress);
            },
            minTestLosses);

        // If we're evaluating using a hold-out set we will not retrain on the
        // full data set at the end. Note captureBest must be evaluated first
        // because it records the best hyperparameters as a side effect.
        if (m_Hyperparameters.captureBest(crossValidationResult.s_TestLossMoments,
                                          crossValidationResult.s_MeanLossGap, numberKeptNodes,
                                          crossValidationResult.s_NumberNodes,
                                          crossValidationResult.s_NumberTrees) &&
            m_UserSuppliedHoldOutSet) {
            retrainedTrees = std::move(crossValidationResult.s_Forest);
        }

        if (m_Hyperparameters.selectNext(crossValidationResult.s_TestLossMoments,
                                         this->betweenFoldTestLossVariance()) == false) {
            LOG_INFO(<< "Stopping fine-tune hyperparameters on round "
                     << m_Hyperparameters.currentRound() << " out of "
                     << m_Hyperparameters.numberRounds());
            break;
        }

        std::int64_t memoryUsage{static_cast<std::int64_t>(this->memoryUsage())};
        m_Instrumentation->updateMemoryUsage(memoryUsage - lastMemoryUsage);
        lastMemoryUsage = memoryUsage;

        // We need to update the current round before we persist so we don't
        // perform an extra round when we fail over.
        m_Hyperparameters.startNextRound();

        LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound() << " state recording started");
        this->recordState(recordTrainStateCallback);
        LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound() << " state recording finished");

        std::uint64_t currentLap{stopWatch.lap()};
        std::uint64_t delta{currentLap - lastLap};
        m_Instrumentation->iterationTime(delta);

        timeAccumulator.add(static_cast<double>(delta));
        lastLap = currentLap;
        m_Instrumentation->flush(HYPERPARAMETER_OPTIMIZATION_ROUND +
                                 std::to_string(m_Hyperparameters.currentRound()));
    }

    initialLoss += m_Hyperparameters.modelSizePenalty(numberKeptNodes, retrainedNumberNodes);

    LOG_TRACE(<< "Incremental training finished after "
              << m_Hyperparameters.currentRound() << " iterations. "
              << "Time per iteration in ms mean: "
              << common::CBasicStatistics::mean(timeAccumulator) << " std. dev:  "
              << std::sqrt(common::CBasicStatistics::variance(timeAccumulator)));
    LOG_TRACE(<< "best forest loss = " << m_Hyperparameters.bestForestTestLoss()
              << ", initial loss = " << initialLoss);

    if (m_ForceAcceptIncrementalTraining || m_Hyperparameters.bestForestTestLoss() < initialLoss) {
        m_Hyperparameters.restoreBest();
        m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
        m_Hyperparameters.captureScale();

        if (retrainedTrees.empty()) {
            this->scaleRegularizationMultipliers(allTrainingRowMask.manhattan() /
                                                 this->meanNumberTrainingRowsPerFold());

            // Reinitialize random number generator for reproducible results.
            m_Rng.seed(m_Seed);

            retrainedTrees = this->updateForest(frame, allTrainingRowMask,
                                                allTrainingRowMask, m_TrainingProgress)
                                 .s_Forest;
        }

        for (std::size_t i = 0; i < retrainedTrees.size(); ++i) {
            m_BestForest[m_TreesToRetrain[i]] = std::move(retrainedTrees[i]);
        }

        // Resize the forest to eliminate the unused dummy trees. We keep every
        // original tree plus any dummy slots that were filled by new trees.
        // If no trees were produced, m_TreesToRetrain[retrainedTrees.size() - 1]
        // would wrap the index around, so just drop all the dummy trees.
        std::size_t bestForestSize{m_BestForest.size() - m_MaximumNumberNewTrees};
        if (retrainedTrees.empty() == false) {
            auto lastChangedTreeIndex = m_TreesToRetrain[retrainedTrees.size() - 1];
            bestForestSize = std::max(static_cast<std::size_t>(lastChangedTreeIndex + 1),
                                      bestForestSize);
        }
        m_BestForest.resize(bestForestSize);
    } else {
        // The incremental update was rejected, so the original forest is kept.
        // Remove the dummy trees appended above so they don't linger in the
        // model (and in the SHAP / classification weight computations below).
        m_BestForest.resize(m_BestForest.size() - m_MaximumNumberNewTrees);
    }

    this->computeClassificationWeights(frame);
    this->initializeTreeShap(frame);

    // Force progress to one and record the final memory usage.
    m_Instrumentation->updateProgress(1.0);
    m_Instrumentation->updateMemoryUsage(
        static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}