void CBoostedTreeImpl::train()

in lib/maths/analytics/CBoostedTreeImpl.cc [293:444]


// Train the boosted tree model on the rows of frame selected by allTrainingRowMask().
//
// There are three paths:
//   1. canTrain() is false: the forest is set to the single tree returned by
//      initializePredictionsAndLossDerivatives (the constant predictor which
//      minimises the loss). Its mean loss over all training rows is captured as
//      the best test loss.
//   2. The fine-tune hyperparameter search has not finished, or there is no best
//      forest yet: run (or continue) the search loop. Each round cross-validates
//      a forest and persists state after the round. Then, unless a forest was
//      kept during the search (only possible with a user supplied hold-out set),
//      restore the best hyperparameters and train the final forest on all
//      training rows.
//   3. Otherwise the search is finished and a forest already exists, so final
//      training is skipped. Presumably this is the resume-from-persisted-state
//      case; confirm.
// All paths then compute classification weights, initialise tree SHAP, force
// progress to one and report the final memory usage delta.
//
// frame: the data frame holding the training rows. It is passed by non-const
// reference to the training helpers.
// recordTrainStateCallback: passed to recordState(). That happens after each
// completed search round and after the final forest is trained.
void CBoostedTreeImpl::train(core::CDataFrame& frame,
                             const TTrainingStateCallback& recordTrainStateCallback) {

    this->checkTrainInvariants(frame);

    // Tell instrumentation whether this is a regression or a classification
    // task, based on the loss function.
    m_Instrumentation->type(
        m_Loss->isRegression()
            ? CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression
            : CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);

    LOG_TRACE(<< "Main training loop...");

    // Route training progress through the instrumentation's progress callback.
    m_TrainingProgress.progressCallback(m_Instrumentation->progressCallback());

    // Baseline for the incremental memory usage deltas reported to instrumentation.
    std::int64_t lastMemoryUsage(this->memoryUsage());

    // Mask selecting every training row, plus an all-false mask of the same size.
    // The all-false mask is only used by the constant predictor fallback below,
    // presumably as its (empty) test row mask; confirm against
    // initializePredictionsAndLossDerivatives.
    core::CPackedBitVector allTrainingRowMask{this->allTrainingRowMask()};
    core::CPackedBitVector noRowsMask{allTrainingRowMask.size(), false};

    this->startProgressMonitoringFineTuneHyperparameters();

    if (this->canTrain() == false) {

        // Fallback to using the constant predictor which minimises the loss.

        this->startProgressMonitoringFinalTrain();
        m_BestForest.assign(1, this->initializePredictionsAndLossDerivatives(
                                   frame, allTrainingRowMask, noRowsMask));
        // Record the mean loss over all training rows as the best test loss for
        // this one-tree, one-node forest.
        TMeanVarAccumulator testLossMoments;
        testLossMoments += this->meanLoss(frame, allTrainingRowMask);
        m_Hyperparameters.captureBest(
            testLossMoments, 0.0 /*no loss gap*/, 0.0 /*no kept nodes*/,
            1.0 /*single node used to centre the data*/, 1 /*single tree*/);
        LOG_TRACE(<< "Test loss = " << m_Hyperparameters.bestForestTestLoss());

    } else if (m_Hyperparameters.fineTuneSearchNotFinished() || m_BestForest.empty()) {
        // Wall clock time of each optimisation round, for the trace log at the end.
        TMeanVarAccumulator timeAccumulator;
        core::CStopWatch stopWatch;
        stopWatch.start();
        std::uint64_t lastLap{stopWatch.lap()};

        // Hyperparameter optimisation loop.

        // Per-fold test losses handed to crossValidateForest each round. The
        // callback receives a minTestLoss which it forwards to trainForest.
        TDoubleVec minTestLosses{this->initializePerFoldTestLosses()};

        // The increment expression is intentionally empty. The round is advanced
        // by startNextRound() inside the body so that it happens before state is
        // persisted (see below).
        for (m_Hyperparameters.startFineTuneSearch();
             m_Hyperparameters.fineTuneSearchNotFinished();
             /**/) {

            LOG_TRACE(<< "Optimisation round = " << m_Hyperparameters.currentRound() + 1);
            m_Instrumentation->iteration(m_Hyperparameters.currentRound() + 1);

            this->recordHyperparameters();

            // Cross-validate a forest using the current hyperparameters. The lambda
            // adapts trainForest's argument order (progress before minTestLoss)
            // to the callback signature crossValidateForest uses.
            auto crossValidationResult = this->crossValidateForest(
                frame, m_Hyperparameters.maximumNumberTrees().value(),
                [this](core::CDataFrame& frame_, const core::CPackedBitVector& trainingRowMask,
                       const core::CPackedBitVector& testingRowMask,
                       double minTestLoss, core::CLoopProgress& trainingProgress) {
                    return this->trainForest(frame_, trainingRowMask, testingRowMask,
                                             trainingProgress, minTestLoss);
                },
                minTestLosses);

            // If we're evaluating using a hold-out set we will not retrain on the
            // full data set at the end.
            // captureBest is the left operand of && so it is always evaluated. The
            // cross-validated forest is only kept when captureBest returns true
            // (presumably meaning a new best) and a user supplied hold-out set is
            // in use.
            if (m_Hyperparameters.captureBest(
                    crossValidationResult.s_TestLossMoments,
                    crossValidationResult.s_MeanLossGap, 0.0 /*no kept nodes*/,
                    crossValidationResult.s_NumberNodes, crossValidationResult.s_NumberTrees) &&
                m_UserSuppliedHoldOutSet) {
                m_BestForest = std::move(crossValidationResult.s_Forest);
            }

            // selectNext returning false ends the search early. Breaking here skips
            // the memory update, startNextRound(), recordState(), the iteration
            // timing and the flush for this round.
            if (m_Hyperparameters.selectNext(crossValidationResult.s_TestLossMoments,
                                             this->betweenFoldTestLossVariance()) == false) {
                // NOTE(review): this logs currentRound() while the trace at the top
                // of the loop labels the same round currentRound() + 1. Confirm
                // which numbering is intended.
                LOG_INFO(<< "Stopping fine-tune hyperparameters on round "
                         << m_Hyperparameters.currentRound() << " out of "
                         << m_Hyperparameters.numberRounds());
                break;
            }

            // Report the change in memory usage since the last report.
            std::int64_t memoryUsage(this->memoryUsage());
            m_Instrumentation->updateMemoryUsage(memoryUsage - lastMemoryUsage);
            lastMemoryUsage = memoryUsage;

            // We need to update the current round before we persist so we don't
            // perform an extra round when we fail over.
            m_Hyperparameters.startNextRound();

            // Store the training state after each hyperparameter search step.
            LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound()
                      << " state recording started");
            this->recordState(recordTrainStateCallback);
            LOG_TRACE(<< "Round " << m_Hyperparameters.currentRound()
                      << " state recording finished");

            // Record this round's elapsed time with instrumentation and the
            // local accumulator, then flush instrumentation for the round.
            std::uint64_t currentLap{stopWatch.lap()};
            std::uint64_t delta{currentLap - lastLap};
            m_Instrumentation->iterationTime(delta);

            timeAccumulator.add(static_cast<double>(delta));
            lastLap = currentLap;
            m_Instrumentation->flush(HYPERPARAMETER_OPTIMIZATION_ROUND +
                                     std::to_string(m_Hyperparameters.currentRound()));
        }

        LOG_TRACE(<< "Test loss = " << m_Hyperparameters.bestForestTestLoss());

        // Train the final forest if none was kept during the search. Without a
        // user supplied hold-out set the loop above never assigns m_BestForest.
        if (m_BestForest.empty()) {
            m_Hyperparameters.restoreBest();
            m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
            m_Hyperparameters.captureScale();
            this->startProgressMonitoringFinalTrain();
            // The final forest trains on all training rows, whereas the search
            // trained on cross-validation folds. Rescale the regularisation
            // multipliers by the ratio of the two row counts.
            this->scaleRegularizationMultipliers(allTrainingRowMask.manhattan() /
                                                 this->meanNumberTrainingRowsPerFold());

            // Reinitialize random number generator for reproducible results.
            m_Rng.seed(m_Seed);

            // All training rows are used as both the training and the testing mask.
            m_BestForest = this->trainForest(frame, allTrainingRowMask,
                                             allTrainingRowMask, m_TrainingProgress)
                               .s_Forest;

            this->recordState(recordTrainStateCallback);
        } else {
            this->skipProgressMonitoringFinalTrain();
        }
        m_Instrumentation->iteration(m_Hyperparameters.currentRound());
        m_Instrumentation->flush(TRAIN_FINAL_FOREST);

        // Add the time since the last completed round as one more sample. This
        // covers the final training, or its skip, plus any round that ended early
        // via selectNext.
        timeAccumulator.add(static_cast<double>(stopWatch.stop() - lastLap));

        LOG_TRACE(<< "Training finished after " << m_Hyperparameters.currentRound()
                  << " iterations. Time per iteration in ms mean: "
                  << common::CBasicStatistics::mean(timeAccumulator) << " std. dev:  "
                  << std::sqrt(common::CBasicStatistics::variance(timeAccumulator)));

        // Publish the size of the trained forest as a program counter.
        core::CProgramCounters::counter(counter_t::E_DFTPMTrainedForestNumberTrees) =
            m_BestForest.size();
    } else {
        // The search has finished and a forest already exists: nothing to train.
        this->skipProgressMonitoringFinalTrain();
    }

    // Post-training steps run on every path.
    this->computeClassificationWeights(frame);
    this->initializeTreeShap(frame);

    // Force progress to one and record the final memory usage.
    m_Instrumentation->updateProgress(1.0);
    m_Instrumentation->updateMemoryUsage(
        static_cast<std::int64_t>(this->memoryUsage()) - lastMemoryUsage);
}