bool CBoostedTreeImpl::acceptRestoreTraverser()

in lib/maths/analytics/CBoostedTreeImpl.cc [2517:2626]


// Restore this object's state from \p traverser.
//
// The traverser's current element name is first checked against
// SUPPORTED_VERSIONS; if it is not one of them an error is logged and false is
// returned without reading any further state. Otherwise every element at this
// level is visited and dispatched by name to the matching member restore.
// NOTE(review): the RESTORE* macros are defined elsewhere; presumably RESTORE
// and RESTORE_SETUP_TEARDOWN log and return false when the restore expression
// fails, while RESTORE_NO_ERROR does not check a result — confirm against the
// macro definitions.
//
// \return false if the state version is unsupported (or, presumably, if a
// member restore fails); true once all elements have been consumed and the
// restored invariants have been checked.
bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
    if (std::find(SUPPORTED_VERSIONS.begin(), SUPPORTED_VERSIONS.end(),
                  traverser.name()) == SUPPORTED_VERSIONS.end()) {
        LOG_ERROR(<< "Input error: unsupported state serialization version. "
                  << "Currently supported versions: " << SUPPORTED_VERSIONS << ".");
        return false;
    }

    // The loss is a polymorphic object persisted at its own sub-level; a null
    // result from CLoss::restoreLoss signals failure.
    auto restoreLoss = [this](core::CStateRestoreTraverser& traverser_) {
        m_Loss = CLoss::restoreLoss(traverser_);
        return m_Loss != nullptr;
    };

    // If no initialization stage was persisted the object is treated as fully
    // initialized. It is read into a plain int and only converted to the enum
    // once the whole level has been consumed.
    int initializationStage{static_cast<int>(E_FullyInitialized)};

    // Note the loop starts on the version element checked above; it matches none
    // of the tags below and is therefore skipped.
    do {
        const std::string& name{traverser.name()};
        RESTORE(BEST_FOREST_TAG,
                core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
        // The override is optional: engage it with an empty vector before
        // restoring into the contained value.
        RESTORE_SETUP_TEARDOWN(
            CLASSIFICATION_WEIGHTS_OVERRIDE_TAG,
            m_ClassificationWeightsOverride = TStrDoublePrVec{},
            core::CPersistUtils::restore(CLASSIFICATION_WEIGHTS_OVERRIDE_TAG,
                                         *m_ClassificationWeightsOverride, traverser),
            /*no-op*/)
        RESTORE(DATA_SUMMARIZATION_FRACTION_TAG,
                core::CPersistUtils::restore(DATA_SUMMARIZATION_FRACTION_TAG,
                                             m_DataSummarizationFraction, traverser))
        RESTORE(DEPENDENT_VARIABLE_TAG,
                core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
                                             m_DependentVariable, traverser))
        RESTORE_NO_ERROR(ENCODER_TAG,
                         m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
        RESTORE(FEATURE_DATA_TYPES_TAG,
                core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
                                             m_FeatureDataTypes, traverser))
        RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
                core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
                                             m_FeatureSampleProbabilities, traverser))
        RESTORE(FOLD_ROUND_TEST_LOSSES_TAG,
                core::CPersistUtils::restore(FOLD_ROUND_TEST_LOSSES_TAG,
                                             m_FoldRoundTestLosses, traverser))
        RESTORE(FORCE_ACCEPT_INCREMENTAL_TRAINING_TAG,
                core::CPersistUtils::restore(FORCE_ACCEPT_INCREMENTAL_TRAINING_TAG,
                                             m_ForceAcceptIncrementalTraining, traverser))
        RESTORE(HYPERPARAMETERS_TAG,
                core::CPersistUtils::restore(HYPERPARAMETERS_TAG, m_Hyperparameters, traverser))
        RESTORE(INITIALIZATION_STAGE_TAG,
                core::CPersistUtils::restore(INITIALIZATION_STAGE_TAG,
                                             initializationStage, traverser))
        RESTORE(LOSS_TAG, traverser.traverseSubLevel(restoreLoss))
        RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
                core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
                                             m_MaximumAttemptsToAddTree, traverser))
        RESTORE(MAXIMUM_NUMBER_NEW_TREES_TAG,
                core::CPersistUtils::restore(MAXIMUM_NUMBER_NEW_TREES_TAG,
                                             m_MaximumNumberNewTrees, traverser))
        RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
                core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
                                             m_MissingFeatureRowMasks, traverser))
        RESTORE(NEW_TRAINING_ROW_MASK_TAG,
                core::CPersistUtils::restore(NEW_TRAINING_ROW_MASK_TAG,
                                             m_NewTrainingRowMask, traverser))
        RESTORE(NUMBER_FOLDS_TAG,
                core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
        RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
                core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
                                             m_NumberSplitsPerFeature, traverser))
        RESTORE(NUMBER_THREADS_TAG,
                core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
        RESTORE(NUMBER_TOP_SHAP_VALUES_TAG,
                core::CPersistUtils::restore(NUMBER_TOP_SHAP_VALUES_TAG,
                                             m_NumberTopShapValues, traverser))
        RESTORE(PREVIOUS_TRAIN_LOSS_GAP_TAG,
                core::CPersistUtils::restore(PREVIOUS_TRAIN_LOSS_GAP_TAG,
                                             m_PreviousTrainLossGap, traverser))
        RESTORE(PREVIOUS_TRAIN_NUMBER_ROWS_TAG,
                core::CPersistUtils::restore(PREVIOUS_TRAIN_NUMBER_ROWS_TAG,
                                             m_PreviousTrainNumberRows, traverser))
        // The generator state is persisted as a single string value.
        RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
        RESTORE(RETRAIN_FRACTION_TAG,
                core::CPersistUtils::restore(RETRAIN_FRACTION_TAG, m_RetrainFraction, traverser))
        RESTORE(ROWS_PER_FEATURE_TAG,
                core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
        RESTORE(SEED_TAG, core::CPersistUtils::restore(SEED_TAG, m_Seed, traverser))
        RESTORE(STOP_CROSS_VALIDATION_EARLY_TAG,
                core::CPersistUtils::restore(STOP_CROSS_VALIDATION_EARLY_TAG,
                                             m_StopCrossValidationEarly, traverser))
        RESTORE(TESTING_ROW_MASKS_TAG,
                core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG, m_TestingRowMasks, traverser))
        RESTORE(TRAINING_ROW_MASKS_TAG,
                core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG, m_TrainingRowMasks, traverser))
        RESTORE(TRAIN_FRACTION_PER_FOLD_TAG,
                core::CPersistUtils::restore(TRAIN_FRACTION_PER_FOLD_TAG,
                                             m_TrainFractionPerFold, traverser))
        RESTORE(TREES_TO_RETRAIN_TAG,
                core::CPersistUtils::restore(TREES_TO_RETRAIN_TAG, m_TreesToRetrain, traverser))
        RESTORE(USER_SUPPLIED_HOLD_OUT_SET_TAG,
                core::CPersistUtils::restore(USER_SUPPLIED_HOLD_OUT_SET_TAG,
                                             m_UserSuppliedHoldOutSet, traverser))
    } while (traverser.next());

    // Extra column information is recreated when training state is restored.

    // NOTE(review): the persisted integer is converted to the enum without a
    // range check; consider rejecting values outside EInitializationStage.
    m_InitializationStage = static_cast<EInitializationStage>(initializationStage);

    this->checkRestoredInvariants();

    return true;
}