in lib/maths/analytics/CBoostedTreeImpl.cc [2517:2626]
// Restore the boosted tree implementation state from \p traverser.
//
// The traverser must be positioned on a state serialization version tag that
// is listed in SUPPORTED_VERSIONS; otherwise an error is logged and false is
// returned before any member is touched. The remaining tags at this level are
// then visited in turn and each is dispatched to its member by the RESTORE*
// macro entries below. The RESTORE* macros are defined elsewhere; presumably
// they log and return false when a restore expression fails, and a tag that
// matches no entry simply falls through to traverser.next().
//
// On reaching the end of the level the restored initialization stage is
// applied, checkRestoredInvariants() is run and true is returned.
bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
    if (std::find(SUPPORTED_VERSIONS.begin(), SUPPORTED_VERSIONS.end(),
                  traverser.name()) == SUPPORTED_VERSIONS.end()) {
        LOG_ERROR(<< "Input error: unsupported state serialization version. "
                  << "Currently supported versions: " << SUPPORTED_VERSIONS << ".");
        return false;
    }

    // The loss is persisted as a nested level; restoring it succeeds only if
    // CLoss::restoreLoss produced a non-null loss object.
    auto restoreLoss = [this](core::CStateRestoreTraverser& traverser_) {
        m_Loss = CLoss::restoreLoss(traverser_);
        return m_Loss != nullptr;
    };

    // Used when the state contains no INITIALIZATION_STAGE_TAG.
    int initializationStage{static_cast<int>(E_FullyInitialized)};

    do {
        // Not referenced directly: the RESTORE* macros compare it against
        // each tag, so this local (and "traverser") must keep these names.
        const std::string& name{traverser.name()};
        RESTORE(BEST_FOREST_TAG,
                core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
        // Engage the override with an empty vector first so it can be
        // dereferenced and filled in place.
        RESTORE_SETUP_TEARDOWN(
            CLASSIFICATION_WEIGHTS_OVERRIDE_TAG,
            m_ClassificationWeightsOverride = TStrDoublePrVec{},
            core::CPersistUtils::restore(CLASSIFICATION_WEIGHTS_OVERRIDE_TAG,
                                         *m_ClassificationWeightsOverride, traverser),
            /*no-op*/)
        RESTORE(DATA_SUMMARIZATION_FRACTION_TAG,
                core::CPersistUtils::restore(DATA_SUMMARIZATION_FRACTION_TAG,
                                             m_DataSummarizationFraction, traverser))
        RESTORE(DEPENDENT_VARIABLE_TAG,
                core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
                                             m_DependentVariable, traverser))
        // The encoder is constructed directly from the traverser, so there is
        // no boolean result to check.
        RESTORE_NO_ERROR(ENCODER_TAG,
                         m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
        RESTORE(FEATURE_DATA_TYPES_TAG,
                core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
                                             m_FeatureDataTypes, traverser))
        RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
                core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
                                             m_FeatureSampleProbabilities, traverser))
        RESTORE(FOLD_ROUND_TEST_LOSSES_TAG,
                core::CPersistUtils::restore(FOLD_ROUND_TEST_LOSSES_TAG,
                                             m_FoldRoundTestLosses, traverser))
        RESTORE(FORCE_ACCEPT_INCREMENTAL_TRAINING_TAG,
                core::CPersistUtils::restore(FORCE_ACCEPT_INCREMENTAL_TRAINING_TAG,
                                             m_ForceAcceptIncrementalTraining, traverser))
        RESTORE(HYPERPARAMETERS_TAG,
                core::CPersistUtils::restore(HYPERPARAMETERS_TAG, m_Hyperparameters, traverser))
        RESTORE(INITIALIZATION_STAGE_TAG,
                core::CPersistUtils::restore(INITIALIZATION_STAGE_TAG,
                                             initializationStage, traverser))
        RESTORE(LOSS_TAG, traverser.traverseSubLevel(restoreLoss))
        RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
                core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
                                             m_MaximumAttemptsToAddTree, traverser))
        RESTORE(MAXIMUM_NUMBER_NEW_TREES_TAG,
                core::CPersistUtils::restore(MAXIMUM_NUMBER_NEW_TREES_TAG,
                                             m_MaximumNumberNewTrees, traverser))
        RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
                core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
                                             m_MissingFeatureRowMasks, traverser))
        RESTORE(NEW_TRAINING_ROW_MASK_TAG,
                core::CPersistUtils::restore(NEW_TRAINING_ROW_MASK_TAG,
                                             m_NewTrainingRowMask, traverser))
        RESTORE(NUMBER_FOLDS_TAG,
                core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
        RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
                core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
                                             m_NumberSplitsPerFeature, traverser))
        RESTORE(NUMBER_THREADS_TAG,
                core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
        RESTORE(NUMBER_TOP_SHAP_VALUES_TAG,
                core::CPersistUtils::restore(NUMBER_TOP_SHAP_VALUES_TAG,
                                             m_NumberTopShapValues, traverser))
        RESTORE(PREVIOUS_TRAIN_LOSS_GAP_TAG,
                core::CPersistUtils::restore(PREVIOUS_TRAIN_LOSS_GAP_TAG,
                                             m_PreviousTrainLossGap, traverser))
        RESTORE(PREVIOUS_TRAIN_NUMBER_ROWS_TAG,
                core::CPersistUtils::restore(PREVIOUS_TRAIN_NUMBER_ROWS_TAG,
                                             m_PreviousTrainNumberRows, traverser))
        RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
        RESTORE(RETRAIN_FRACTION_TAG,
                core::CPersistUtils::restore(RETRAIN_FRACTION_TAG, m_RetrainFraction, traverser))
        RESTORE(ROWS_PER_FEATURE_TAG,
                core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
        RESTORE(SEED_TAG, core::CPersistUtils::restore(SEED_TAG, m_Seed, traverser))
        RESTORE(STOP_CROSS_VALIDATION_EARLY_TAG,
                core::CPersistUtils::restore(STOP_CROSS_VALIDATION_EARLY_TAG,
                                             m_StopCrossValidationEarly, traverser))
        RESTORE(TESTING_ROW_MASKS_TAG,
                core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG, m_TestingRowMasks, traverser))
        RESTORE(TRAINING_ROW_MASKS_TAG,
                core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG, m_TrainingRowMasks, traverser))
        RESTORE(TRAIN_FRACTION_PER_FOLD_TAG,
                core::CPersistUtils::restore(TRAIN_FRACTION_PER_FOLD_TAG,
                                             m_TrainFractionPerFold, traverser))
        RESTORE(TREES_TO_RETRAIN_TAG,
                core::CPersistUtils::restore(TREES_TO_RETRAIN_TAG, m_TreesToRetrain, traverser))
        RESTORE(USER_SUPPLIED_HOLD_OUT_SET_TAG,
                core::CPersistUtils::restore(USER_SUPPLIED_HOLD_OUT_SET_TAG,
                                             m_UserSuppliedHoldOutSet, traverser))
    } while (traverser.next());

    // Extra column information is not restored here: it is recreated when
    // training state is restored.

    // NOTE(review): initializationStage comes from persisted state and is cast
    // to the enum without a range check; presumably checkRestoredInvariants()
    // guards against invalid stages - confirm.
    m_InitializationStage = static_cast<EInitializationStage>(initializationStage);
    this->checkRestoredInvariants();
    return true;
}