def initialize()

in src/sagemaker_xgboost_container/algorithm_mode/hyperparameter_validation.py [0:0]


def initialize(metrics):
    """Build the XGBoost hyperparameter specification used for validation.

    Declares every supported hyperparameter (type, range, tunability and
    cross-parameter dependency checks) in an ``hpv.Hyperparameters`` container
    and registers the alternative names ``learning_rate``, ``min_split_loss``,
    ``reg_lambda`` and ``reg_alpha``.

    Args:
        metrics: object whose ``names`` attribute is used as the allowed range
            of the ``_tuning_objective_metric`` hyperparameter.

    Returns:
        The populated ``hpv.Hyperparameters`` instance.

    Note:
        The nested validators raise ``exc.UserError`` when a value is rejected.
    """
    # Updater names are declared once so that the 'updater' range below and
    # updater_validator can never disagree about what is accepted.
    tree_grow_updaters = ['grow_colmaker', 'distcol', 'grow_histmaker', 'grow_local_histmaker',
                          'grow_skmaker', 'grow_quantile_histmaker']
    tree_updaters = tree_grow_updaters + ['sync', 'refresh', 'prune']
    linear_updaters = ['shotgun', 'coord_descent']
    process_update_updaters = ['refresh', 'prune']

    def quote_options(options):
        """Render names as a comma-separated list of single-quoted items."""
        return ', '.join("'{0}'".format(option) for option in options)

    @hpv.range_validator(["auto", "exact", "approx", "hist", "gpu_hist"])
    def tree_method_range_validator(CATEGORIES, value):
        """Return True when ``value`` is one of the supported tree methods."""
        return value in CATEGORIES

    @hpv.dependencies_validator(["booster", "process_type"])
    def updater_validator(value, dependencies):
        """Check the list of updaters against 'booster' and 'process_type'.

        Args:
            value: sequence of updater names.
            dependencies: mapping that may hold 'booster' and 'process_type'.

        Raises:
            exc.UserError: if the updaters do not fit the selected booster or
                process type, or more than one tree grow updater is listed.
        """
        if dependencies.get('booster') == 'gblinear':
            # validate only one linear updater is selected
            if not (len(value) == 1 and value[0] in linear_updaters):
                raise exc.UserError("Linear updater should be one of these options: {}.".format(
                    quote_options(linear_updaters)))
        elif dependencies.get('process_type') == 'update':
            if not all(x in process_update_updaters for x in value):
                raise exc.UserError("process_type 'update' can only be used with updater 'refresh' and 'prune'")
        else:
            if not all(x in tree_updaters for x in value):
                raise exc.UserError("Tree updater should be selected from these options: {}.".format(
                    quote_options(tree_updaters)))
            # validate only one tree updater is selected
            grow_count = sum(1 for x in value if x in tree_grow_updaters)
            if grow_count > 1:
                raise exc.UserError("Only one tree grow plugin can be selected. Choose one from the "
                                    "following: {}.".format(quote_options(tree_grow_updaters)))

    @hpv.range_validator(["auto", "cpu_predictor", "gpu_predictor"])
    def predictor_validator(CATEGORIES, value):
        """Return True when ``value`` is one of the supported predictors."""
        return value in CATEGORIES

    @hpv.dependencies_validator(["num_class"])
    def objective_validator(value, dependencies):
        """Check that 'objective' and 'num_class' are used together consistently.

        Raises:
            exc.UserError: if a multi-class objective is given without
                'num_class', or 'num_class' is given while ``value`` is None.
        """
        num_class = dependencies.get("num_class")
        if value in ("multi:softmax", "multi:softprob") and num_class is None:
            raise exc.UserError("Require input for parameter 'num_class' for multi-classification")
        # NOTE(review): this only fires when the objective value itself is None;
        # a non-multi objective combined with 'num_class' is accepted. Confirm
        # whether that is intended before tightening.
        if value is None and num_class is not None:
            raise exc.UserError("Do not need to setup parameter 'num_class' for learning task other than "
                                "multi-classification.")

    @hpv.range_validator(XGB_MAXIMIZE_METRICS + XGB_MINIMIZE_METRICS)
    def eval_metric_range_validator(SUPPORTED_METRIC, metric):
        """Return True when ``metric`` is a supported evaluation metric.

        Metrics of the form ``name@threshold`` are accepted when ``name`` is
        'error', 'ndcg' or 'map' and ``threshold`` parses as a float.

        Raises:
            exc.UserError: for a function repr, an unsupported thresholded
                metric name, or a non-float threshold.
        """
        if "<function" in metric:
            raise exc.UserError("User defined evaluation metric {} is not supported yet.".format(metric))

        if "@" in metric:
            parts = metric.split('@')
            metric_name = parts[0].strip()
            metric_threshold = parts[1].strip()
            if metric_name not in ["error", "ndcg", "map"]:
                raise exc.UserError(
                    "Metric '{}' is not supported. Parameter 'eval_metric' with customized threshold should "
                    "be one of these options: 'error', 'ndcg', 'map'.".format(metric))
            try:
                float(metric_threshold)
            except ValueError:
                raise exc.UserError("Threshold value 't' in '{}@t' expects float input.".format(metric_name))
            return True

        return metric in SUPPORTED_METRIC

    @hpv.dependencies_validator(["objective"])
    def eval_metric_dep_validator(value, dependencies):
        """Check that objective-specific metrics match the selected objective.

        Raises:
            exc.UserError: if 'auc' is requested without a 'binary:' or 'rank:'
                objective, or 'aft-nloglik' without the 'survival:aft' objective.
        """
        # Read the dependency with .get(), like the other validators here, so a
        # missing or None objective yields a UserError for the metrics that need
        # one instead of a KeyError/AttributeError for every metric.
        objective = dependencies.get("objective") or ""
        if "auc" in value:
            if not objective.startswith(('binary:', 'rank:')):
                raise exc.UserError("Metric 'auc' can only be applied for classification and ranking problems.")
        if "aft-nloglik" in value:
            if objective != "survival:aft":
                raise exc.UserError("Metric 'aft-nloglik' can only be applied for 'survival:aft' objective.")

    @hpv.dependencies_validator(["tree_method"])
    def monotone_constraints_validator(value, dependencies):
        """Require tree_method 'exact' or 'hist' when monotone constraints are set.

        Raises:
            exc.UserError: if ``value`` is not None and 'tree_method' is anything else.
        """
        tree_method = dependencies.get("tree_method")
        if value is not None and tree_method not in ("exact", "hist"):
            raise exc.UserError("monotone_constraints can be used only when the tree_method parameter is set to "
                                "either 'exact' or 'hist'.")

    @hpv.dependencies_validator(["tree_method"])
    def interaction_constraints_validator(value, dependencies):
        """Require tree_method 'exact', 'hist' or 'approx' when interaction constraints are set.

        Raises:
            exc.UserError: if ``value`` is not None and 'tree_method' is anything else.
        """
        tree_method = dependencies.get("tree_method")
        if value is not None and tree_method not in ("exact", "hist", "approx"):
            raise exc.UserError("interaction_constraints can be used only when the tree_method parameter is set to "
                                "either 'exact', 'hist' or 'approx'.")

    hyperparameters = hpv.Hyperparameters(
        hpv.IntegerHyperparameter(name="num_round", required=True,
                                  range=hpv.Interval(min_closed=1),
                                  tunable=True, tunable_recommended_range=hpv.Interval(
                                      min_closed=1,
                                      max_closed=4000,
                                      scale=hpv.Interval.LINEAR_SCALE)),
        hpv.IntegerHyperparameter(name="csv_weights", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.IntegerHyperparameter(name="early_stopping_rounds", range=hpv.Interval(min_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="booster", range=["gbtree", "gblinear", "dart"], required=False),
        hpv.IntegerHyperparameter(name="verbosity", range=hpv.Interval(min_closed=0, max_closed=3), required=False),
        hpv.IntegerHyperparameter(name="nthread", range=hpv.Interval(min_closed=1), required=False),
        hpv.ContinuousHyperparameter(name="eta", range=hpv.Interval(min_closed=0, max_closed=1), required=False,
                                     tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=0.5,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="gamma", range=hpv.Interval(min_closed=0), required=False,
                                     tunable=True, tunable_recommended_range=hpv.Interval(
                                         min_closed=0, max_closed=5,
                                         scale=hpv.Interval.LINEAR_SCALE)),
        hpv.IntegerHyperparameter(name="max_depth", range=hpv.Interval(min_closed=0), required=False,
                                  tunable=True, tunable_recommended_range=hpv.Interval(
                                      min_closed=0, max_closed=10,
                                      scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="min_child_weight", range=hpv.Interval(min_closed=0), required=False,
                                     tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=120,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="max_delta_step", range=hpv.Interval(min_closed=0), required=False,
                                     tunable=True, tunable_recommended_range=hpv.Interval(
                                         min_closed=0, max_closed=10,
                                         scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="subsample", range=hpv.Interval(min_open=0, max_closed=1), required=False,
                                     tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="colsample_bytree", range=hpv.Interval(min_open=0, max_closed=1),
                                     required=False, tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="colsample_bylevel", range=hpv.Interval(min_open=0, max_closed=1),
                                     required=False, tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="colsample_bynode", range=hpv.Interval(min_open=0, max_closed=1),
                                     required=False, tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="lambda", range=hpv.Interval(min_closed=0), required=False,
                                     tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.ContinuousHyperparameter(name="alpha", range=hpv.Interval(min_closed=0), required=False,
                                     tunable=True,
                                     tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000,
                                                                            scale=hpv.Interval.LINEAR_SCALE)),
        hpv.CategoricalHyperparameter(name="tree_method", range=tree_method_range_validator, required=False),
        hpv.ContinuousHyperparameter(name="sketch_eps", range=hpv.Interval(min_open=0, max_open=1), required=False),
        hpv.ContinuousHyperparameter(name="scale_pos_weight", range=hpv.Interval(min_open=0), required=False),
        # Every updater that updater_validator can accept, without duplicates.
        hpv.CommaSeparatedListHyperparameter(name="updater",
                                             range=tree_updaters + linear_updaters,
                                             dependencies=updater_validator,
                                             required=False),
        hpv.CategoricalHyperparameter(name="dsplit", range=["row", "col"], required=False),
        hpv.IntegerHyperparameter(name="refresh_leaf", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="process_type", range=["default", "update"], required=False),
        hpv.CategoricalHyperparameter(name="grow_policy", range=["depthwise", "lossguide"], required=False),
        hpv.IntegerHyperparameter(name="max_leaves", range=hpv.Interval(min_closed=0), required=False),
        hpv.IntegerHyperparameter(name="max_bin", range=hpv.Interval(min_closed=0), required=False),
        hpv.CategoricalHyperparameter(name="predictor", range=predictor_validator, required=False),
        hpv.TupleHyperparameter(name="monotone_constraints", range=[-1, 0, 1], required=False,
                                dependencies=monotone_constraints_validator),
        hpv.NestedListHyperparameter(name="interaction_constraints", range=hpv.Interval(min_closed=1), required=False,
                                     dependencies=interaction_constraints_validator),
        hpv.CategoricalHyperparameter(name="sample_type", range=["uniform", "weighted"], required=False),
        hpv.CategoricalHyperparameter(name="normalize_type", range=["tree", "forest"], required=False),
        hpv.ContinuousHyperparameter(name="rate_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.IntegerHyperparameter(name="one_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.ContinuousHyperparameter(name="skip_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.ContinuousHyperparameter(name="lambda_bias", range=hpv.Interval(min_closed=0, max_closed=1),
                                     required=False),
        hpv.ContinuousHyperparameter(name="tweedie_variance_power", range=hpv.Interval(min_open=1, max_open=2),
                                     required=False),
        # NOTE(review): "aft_loss_distribution" is also declared below as its own
        # hyperparameter; confirm it is meant to be accepted as an objective value.
        hpv.CategoricalHyperparameter(name="objective",
                                      range=["aft_loss_distribution", "binary:logistic", "binary:logitraw",
                                             "binary:hinge", "count:poisson", "multi:softmax", "multi:softprob",
                                             "rank:pairwise", "rank:ndcg", "rank:map", "reg:linear",
                                             "reg:squarederror", "reg:logistic", "reg:gamma", "reg:pseudohubererror",
                                             "reg:squaredlogerror", "reg:tweedie", "survival:aft", "survival:cox"],
                                      dependencies=objective_validator,
                                      required=False),
        hpv.IntegerHyperparameter(name="num_class",
                                  range=hpv.Interval(min_closed=2),
                                  required=False),
        hpv.ContinuousHyperparameter(name="base_score", range=hpv.Interval(min_closed=0), required=False),
        hpv.IntegerHyperparameter(name="_kfold", range=hpv.Interval(min_closed=2), required=False, tunable=False),
        hpv.IntegerHyperparameter(name="_num_cv_round", range=hpv.Interval(min_closed=1),
                                  required=False, tunable=False),
        hpv.CategoricalHyperparameter(name="_tuning_objective_metric", range=metrics.names, required=False),
        hpv.CommaSeparatedListHyperparameter(name="eval_metric",
                                             range=eval_metric_range_validator,
                                             dependencies=eval_metric_dep_validator,
                                             required=False),
        hpv.IntegerHyperparameter(name="seed", range=hpv.Interval(min_open=-2**31, max_open=2**31-1),
                                  required=False),
        hpv.IntegerHyperparameter(name="num_parallel_tree", range=hpv.Interval(min_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="save_model_on_termination", range=["true", "false"], required=False),
        hpv.CategoricalHyperparameter(name="aft_loss_distribution", range=["normal", "logistic", "extreme"],
                                      required=False),
        hpv.ContinuousHyperparameter(name="aft_loss_distribution_scale", range=hpv.Interval(min_closed=0),
                                     required=False),
        hpv.CategoricalHyperparameter(name="single_precision_histogram", range=["true", "false"], required=False),
        hpv.CategoricalHyperparameter(name="deterministic_histogram", range=["true", "false"], required=False),
        )

    # Alternative names accepted for the same hyperparameters.
    hyperparameters.declare_alias("eta", "learning_rate")
    hyperparameters.declare_alias("gamma", "min_split_loss")
    hyperparameters.declare_alias("lambda", "reg_lambda")
    hyperparameters.declare_alias("alpha", "reg_alpha")

    return hyperparameters