void CBoostedTreeFactory::initializeUnsetRegularizationHyperparameters()

in lib/maths/analytics/CBoostedTreeFactory.cc [754:963]


// Initialise the regularisation hyperparameters:
//   - the soft tree depth limit and its tolerance,
//   - the depth penalty multiplier,
//   - the tree size penalty multiplier,
//   - the leaf weight penalty multiplier.
// Each one is only touched while its rangeFixed() is false.
//
// Every line search is wrapped in skipCheckpointIfAtOrAfter with its own
// CBoostedTreeImpl stage. When that call returns true, training progress is
// incremented by lineSearchMaximumNumberIterations(frame). NOTE(review):
// presumably it returns true when the restored state is already at or after
// that stage and the search was skipped; confirm.
//
// The function finishes by calling initializeUnsetPredictionChangeCost() and
// initializeUnsetTreeTopologyPenalty(frame).
//
// \param frame The data frame used to estimate per-node gain and curvature
// percentiles and to run the line searches.
void CBoostedTreeFactory::initializeUnsetRegularizationHyperparameters(core::CDataFrame& frame) {

    // The strategy here is to:
    //   1) Get percentile estimates of the gain in the loss function and its sum
    //      curvature from the splits selected in a single tree with regularisers
    //      zeroed,
    //   2) Use these to extract reasonable intervals to search for the multipliers
    //      for the various regularisation penalties,
    //   3) Line search these intervals for a turning point in the test loss, i.e.
    //      the point at which transition to overfit occurs.
    //
    // We'll search intervals in the vicinity of these values in the hyperparameter
    // optimisation loop.

    auto& hyperparameters = m_TreeImpl->m_Hyperparameters;
    auto& depthPenaltyMultiplierParameter = hyperparameters.depthPenaltyMultiplier();
    auto& leafWeightPenaltyMultiplier = hyperparameters.leafWeightPenaltyMultiplier();
    auto& softTreeDepthLimitParameter = hyperparameters.softTreeDepthLimit();
    auto& softTreeDepthToleranceParameter = hyperparameters.softTreeDepthTolerance();
    auto& treeSizePenaltyMultiplier = hyperparameters.treeSizePenaltyMultiplier();
    // log2 of the maximum tree size for the training row mask at index 0, plus
    // one. This is the initial soft depth limit and also the width of its
    // search interval below.
    double log2MaxTreeSize{std::log2(static_cast<double>(m_TreeImpl->maximumTreeSize(
                               m_TreeImpl->m_TrainingRowMasks[0]))) +
                           1.0};

    // Seed the soft depth limit and tolerance, then estimate the 1st, 50th and
    // 90th percentiles of per-node gain and total curvature.
    // NOTE(review): the set() calls here are not guarded by rangeFixed(); confirm
    // that set() does not overwrite a user-fixed value.
    skipIfAfter(CBoostedTreeImpl::E_EncodingInitialized, [&] {
        softTreeDepthLimitParameter.set(log2MaxTreeSize);
        softTreeDepthToleranceParameter.set(
            0.5 * (MIN_SOFT_DEPTH_LIMIT_TOLERANCE + MAX_SOFT_DEPTH_LIMIT_TOLERANCE));

        auto gainAndTotalCurvaturePerNode =
            this->estimateTreeGainAndCurvature(frame, {1.0, 50.0, 90.0});

        m_GainPerNode1stPercentile = gainAndTotalCurvaturePerNode[0].first;
        m_GainPerNode50thPercentile = gainAndTotalCurvaturePerNode[1].first;
        m_GainPerNode90thPercentile = gainAndTotalCurvaturePerNode[2].first;
        m_TotalCurvaturePerNode1stPercentile = gainAndTotalCurvaturePerNode[0].second;
        m_TotalCurvaturePerNode90thPercentile = gainAndTotalCurvaturePerNode[2].second;

        // Make sure all line search intervals are not empty.
        //
        // This presumably clamps each 1st percentile into [1e-7, 0.1] times the
        // matching 90th percentile. The ratio 90th / 1st, used below to size the
        // search intervals, is then bounded. When the 90th percentile is positive
        // the divisor is also strictly positive.
        m_GainPerNode1stPercentile = common::CTools::truncate(
            m_GainPerNode1stPercentile, 1e-7 * m_GainPerNode90thPercentile,
            0.1 * m_GainPerNode90thPercentile);
        m_TotalCurvaturePerNode1stPercentile = common::CTools::truncate(
            m_TotalCurvaturePerNode1stPercentile, 1e-7 * m_TotalCurvaturePerNode90thPercentile,
            0.1 * m_TotalCurvaturePerNode90thPercentile);

        LOG_TRACE(<< "max depth = " << softTreeDepthLimitParameter.print());
        LOG_TRACE(<< "tolerance = " << softTreeDepthToleranceParameter.print());
        LOG_TRACE(<< "gains and total curvatures per node = " << gainAndTotalCurvaturePerNode);
    });

    // Reset the loss gap and number of trees the line searches fall back to.
    // Fix parameters for which there is no gain or curvature scale, and seed the
    // tree size and leaf weight multipliers.
    // NOTE(review): this uses the same stage as the block above, so the two
    // blocks always run or skip together.
    skipIfAfter(CBoostedTreeImpl::E_EncodingInitialized, [&] {
        m_LossGap = hyperparameters.bestForestLossGap();
        m_NumberTrees = hyperparameters.maximumNumberTrees().value();

        // Zero 90th percentile gain: the gain-based search intervals are
        // degenerate, so fix these parameters. Assuming fixTo() makes rangeFixed()
        // true, this also stops the 90th / 1st ratios below being evaluated as 0/0.
        if (m_GainPerNode90thPercentile == 0.0) {
            if (softTreeDepthLimitParameter.rangeFixed() == false) {
                softTreeDepthLimitParameter.fixTo(MIN_SOFT_DEPTH_LIMIT);
            }
            if (depthPenaltyMultiplierParameter.rangeFixed() == false) {
                depthPenaltyMultiplierParameter.fixTo(0.0);
            }
            if (treeSizePenaltyMultiplier.rangeFixed() == false) {
                treeSizePenaltyMultiplier.fixTo(0.0);
            }
        }
        // Same reasoning for the curvature-based leaf weight penalty.
        if (m_TotalCurvaturePerNode90thPercentile == 0.0 &&
            leafWeightPenaltyMultiplier.rangeFixed() == false) {
            leafWeightPenaltyMultiplier.fixTo(0.0);
        }

        // Initialize regularization multipliers with their minimum permitted values.
        if (treeSizePenaltyMultiplier.rangeFixed() == false) {
            treeSizePenaltyMultiplier.set(minBoundary(
                treeSizePenaltyMultiplier, m_GainPerNode90thPercentile,
                2.0 * m_GainPerNode90thPercentile / m_GainPerNode1stPercentile));
        }
        if (leafWeightPenaltyMultiplier.rangeFixed() == false) {
            leafWeightPenaltyMultiplier.set(minBoundary(
                leafWeightPenaltyMultiplier, m_TotalCurvaturePerNode90thPercentile,
                2.0 * m_TotalCurvaturePerNode90thPercentile / m_TotalCurvaturePerNode1stPercentile));
        }
    });

    // Search for depth limit at which the tree starts to overfit.
    if (softTreeDepthLimitParameter.rangeFixed() == false) {
        if (this->skipCheckpointIfAtOrAfter(CBoostedTreeImpl::E_SoftTreeDepthLimitInitialized, [&] {
                if (m_GainPerNode90thPercentile > 0.0) {
                    // The search is restricted to
                    // [MIN_SOFT_DEPTH_LIMIT, MIN_SOFT_DEPTH_LIMIT + log2MaxTreeSize],
                    // converted into the parameter's search-value space.
                    double maxSoftDepthLimit{MIN_SOFT_DEPTH_LIMIT + log2MaxTreeSize};
                    double minSearchValue{softTreeDepthLimitParameter.toSearchValue(
                        MIN_SOFT_DEPTH_LIMIT)};
                    double maxSearchValue{
                        softTreeDepthLimitParameter.toSearchValue(maxSoftDepthLimit)};
                    // Set the depth penalty to the median per-node gain before
                    // searching the depth limit.
                    // NOTE(review): not guarded by rangeFixed(); confirm that set()
                    // does not overwrite a user-fixed depth penalty.
                    depthPenaltyMultiplierParameter.set(m_GainPerNode50thPercentile);
                    // If the search yields no result, keep the current loss gap and
                    // number of trees (the value_or fallback).
                    std::tie(m_LossGap, m_NumberTrees) =
                        hyperparameters
                            .initializeFineTuneSearchInterval(
                                CBoostedTreeHyperparameters::CInitializeFineTuneArguments{
                                    frame, *m_TreeImpl, maxSoftDepthLimit, log2MaxTreeSize,
                                    [](CBoostedTreeImpl& tree, double softDepthLimit) {
                                        auto& parameter =
                                            tree.m_Hyperparameters.softTreeDepthLimit();
                                        parameter.set(parameter.fromSearchValue(softDepthLimit));
                                        return true;
                                    }}
                                    .truncateParameter([&](TVector& range) {
                                        range = truncate(range, minSearchValue, maxSearchValue);
                                    }),
                                softTreeDepthLimitParameter)
                            .value_or(std::make_pair(m_LossGap, m_NumberTrees));
                } else {
                    // No positive gain scale to search over: fix the depth limit at
                    // its current value.
                    softTreeDepthLimitParameter.fix();
                }
            })) {
            m_TreeImpl->m_TrainingProgress.increment(
                this->lineSearchMaximumNumberIterations(frame));
        }
    }

    // Update the soft depth tolerance.
    if (softTreeDepthToleranceParameter.rangeFixed() == false) {
        softTreeDepthToleranceParameter.fixToRange(MIN_SOFT_DEPTH_LIMIT_TOLERANCE,
                                                   MAX_SOFT_DEPTH_LIMIT_TOLERANCE);
        softTreeDepthToleranceParameter.set(
            0.5 * (MIN_SOFT_DEPTH_LIMIT_TOLERANCE + MAX_SOFT_DEPTH_LIMIT_TOLERANCE));
    }

    // Search for the depth penalty multipliers at which the model starts
    // to overfit.
    //
    // The interval is sized from the 90th percentile gain and the ratio
    // 2 * 90th / 1st percentile gain.
    if (depthPenaltyMultiplierParameter.rangeFixed() == false) {
        if (this->skipCheckpointIfAtOrAfter(CBoostedTreeImpl::E_DepthPenaltyMultiplierInitialized, [&] {
                std::tie(m_LossGap, m_NumberTrees) =
                    hyperparameters
                        .initializeFineTuneSearchInterval(
                            CBoostedTreeHyperparameters::CInitializeFineTuneArguments{
                                frame, *m_TreeImpl, m_GainPerNode90thPercentile,
                                2.0 * m_GainPerNode90thPercentile / m_GainPerNode1stPercentile,
                                [](CBoostedTreeImpl& tree, double depthPenalty) {
                                    auto& parameter =
                                        tree.m_Hyperparameters.depthPenaltyMultiplier();
                                    parameter.set(parameter.fromSearchValue(depthPenalty));
                                    return true;
                                }},
                            depthPenaltyMultiplierParameter)
                        .value_or(std::make_pair(m_LossGap, m_NumberTrees));
            })) {
            m_TreeImpl->m_TrainingProgress.increment(
                this->lineSearchMaximumNumberIterations(frame));
        }
    }

    if (depthPenaltyMultiplierParameter.fixed() &&
        depthPenaltyMultiplierParameter.value() == 0.0) {
        // Lock down the depth and tolerance parameters since they have no effect
        // and adjusting them just wastes time.
        softTreeDepthLimitParameter.fix();
        softTreeDepthToleranceParameter.fix();
    }

    // Search for the value of the tree size penalty multiplier at which the
    // model starts to overfit.
    //
    // It uses the same gain-based interval as the depth penalty.
    if (treeSizePenaltyMultiplier.rangeFixed() == false) {
        if (this->skipCheckpointIfAtOrAfter(CBoostedTreeImpl::E_TreeSizePenaltyMultiplierInitialized, [&] {
                std::tie(m_LossGap, m_NumberTrees) =
                    hyperparameters
                        .initializeFineTuneSearchInterval(
                            CBoostedTreeHyperparameters::CInitializeFineTuneArguments{
                                frame, *m_TreeImpl, m_GainPerNode90thPercentile,
                                2.0 * m_GainPerNode90thPercentile / m_GainPerNode1stPercentile,
                                [](CBoostedTreeImpl& tree, double treeSizePenalty) {
                                    auto& parameter =
                                        tree.m_Hyperparameters.treeSizePenaltyMultiplier();
                                    parameter.set(parameter.fromSearchValue(treeSizePenalty));
                                    return true;
                                }},
                            treeSizePenaltyMultiplier)
                        .value_or(std::make_pair(m_LossGap, m_NumberTrees));
            })) {
            m_TreeImpl->m_TrainingProgress.increment(
                this->lineSearchMaximumNumberIterations(frame));
        }
    }

    // Search for the value of the leaf weight penalty multiplier at which the
    // model starts to overfit.
    //
    // The interval is sized from the total curvature percentiles rather than
    // the gain percentiles.
    if (leafWeightPenaltyMultiplier.rangeFixed() == false) {
        if (this->skipCheckpointIfAtOrAfter(CBoostedTreeImpl::E_LeafWeightPenaltyMultiplierInitialized, [&] {
                std::tie(m_LossGap, m_NumberTrees) =
                    hyperparameters
                        .initializeFineTuneSearchInterval(
                            CBoostedTreeHyperparameters::CInitializeFineTuneArguments{
                                frame, *m_TreeImpl, m_TotalCurvaturePerNode90thPercentile,
                                2.0 * m_TotalCurvaturePerNode90thPercentile / m_TotalCurvaturePerNode1stPercentile,
                                [](CBoostedTreeImpl& tree, double leafWeightPenalty) {
                                    auto& parameter =
                                        tree.m_Hyperparameters.leafWeightPenaltyMultiplier();
                                    parameter.set(parameter.fromSearchValue(leafWeightPenalty));
                                    return true;
                                }},
                            leafWeightPenaltyMultiplier)
                        .value_or(std::make_pair(m_LossGap, m_NumberTrees));
            })) {
            m_TreeImpl->m_TrainingProgress.increment(
                this->lineSearchMaximumNumberIterations(frame));
        }
    }

    // Remaining regularisation-related hyperparameters are handled by these helpers.
    this->initializeUnsetPredictionChangeCost();
    this->initializeUnsetTreeTopologyPenalty(frame);
}