def initialize()

in src/sagemaker_xgboost_container/algorithm_mode/hyperparameter_validation.py [0:0]


def initialize(metrics):
    """Build the collection of hyperparameters accepted by XGBoost algorithm mode.

    Declares the custom range and dependency validators used by individual
    hyperparameters, assembles every supported hyperparameter into an
    ``hpv.Hyperparameters`` collection, and registers the XGBoost aliases
    (``learning_rate``, ``min_split_loss``, ``reg_lambda``, ``reg_alpha``).

    :param metrics: metrics collection whose ``names`` attribute supplies the
        allowed values of ``_tuning_objective_metric``.
    :return: the populated ``hpv.Hyperparameters`` instance.
    """
    # Updater plugin catalogues. They are shared by the ``updater`` range and by
    # ``updater_validator`` so that the two cannot drift apart.
    valid_tree_plugins = [
        "grow_colmaker",
        "distcol",
        "grow_histmaker",
        "grow_skmaker",
        "sync",
        "refresh",
        "prune",
        "grow_quantile_histmaker",
    ]
    # Subset of the tree plugins that actually grow trees; at most one may be chosen.
    valid_tree_build_plugins = [
        "grow_colmaker",
        "distcol",
        "grow_histmaker",
        "grow_skmaker",
        "grow_quantile_histmaker",
    ]
    valid_linear_plugins = ["shotgun", "coord_descent"]
    valid_process_update_plugins = ["refresh", "prune"]

    def _quoted(names):
        """Return ``names`` as a comma-separated list of single-quoted strings."""
        return ", ".join("'{}'".format(name) for name in names)

    @hpv.range_validator(["auto", "exact", "approx", "hist", "gpu_hist"])
    def tree_method_range_validator(CATEGORIES, value):
        """Return whether ``value`` is one of the supported tree methods."""
        return value in CATEGORIES

    @hpv.dependencies_validator(["booster", "process_type"])
    def updater_validator(value, dependencies):
        """Check that the selected updater plugins are compatible with booster/process_type.

        :param value: sequence of updater plugin names (indexed and iterated below).
        :param dependencies: mapping holding the ``booster`` and ``process_type`` values.
        :raises exc.UserError: when a plugin does not fit the booster/process type, or
            when more than one tree-growing plugin is selected.
        """
        if dependencies.get("booster") == "gblinear":
            # A linear booster accepts exactly one linear updater.
            if not (len(value) == 1 and value[0] in valid_linear_plugins):
                raise exc.UserError(
                    "Linear updater should be one of these options: {}.".format(_quoted(valid_linear_plugins))
                )
        elif dependencies.get("process_type") == "update":
            if not all(x in valid_process_update_plugins for x in value):
                raise exc.UserError("process_type 'update' can only be used with updater 'refresh' and 'prune'")
        else:
            if not all(x in valid_tree_plugins for x in value):
                raise exc.UserError(
                    "Tree updater should be selected from these options: {}.".format(_quoted(valid_tree_plugins))
                )
            # Only one tree-growing plugin may be selected at a time.
            if sum(1 for x in value if x in valid_tree_build_plugins) > 1:
                raise exc.UserError(
                    "Only one tree grow plugin can be selected. Choose one from the following: {}.".format(
                        _quoted(valid_tree_build_plugins)
                    )
                )

    @hpv.range_validator(["auto", "cpu_predictor", "gpu_predictor"])
    def predictor_validator(CATEGORIES, value):
        """Return whether ``value`` is one of the supported predictors."""
        return value in CATEGORIES

    @hpv.dependencies_validator(["num_class"])
    def objective_validator(value, dependencies):
        """Check that ``num_class`` is supplied for multi-class objectives.

        :raises exc.UserError: when a multi-class objective lacks ``num_class``, or when
            ``num_class`` is set while ``value`` is ``None``.
        """
        num_class = dependencies.get("num_class")
        if value in ("multi:softmax", "multi:softprob") and num_class is None:
            raise exc.UserError("Require input for parameter 'num_class' for multi-classification")
        # NOTE(review): this branch only fires if hpv invokes the validator with an unset
        # objective; a non-multi-class objective combined with num_class is not rejected.
        if value is None and num_class is not None:
            raise exc.UserError(
                "Do not need to setup parameter 'num_class' for learning task other than multi-classification."
            )

    @hpv.range_validator(XGB_MAXIMIZE_METRICS + XGB_MINIMIZE_METRICS)
    def eval_metric_range_validator(SUPPORTED_METRIC, metric):
        """Return whether ``metric`` is an accepted ``eval_metric`` entry.

        Entries of the form ``name@t`` are accepted only for 'error', 'ndcg' and 'map'
        with a float-parsable ``t``; any other entry must appear in ``SUPPORTED_METRIC``.

        :raises exc.UserError: for user-defined (function) metrics, unsupported ``@``
            metrics, or a threshold that cannot be parsed as a float.
        """
        if "<function" in metric:
            raise exc.UserError("User defined evaluation metric {} is not supported yet.".format(metric))

        if "@" in metric:
            parts = metric.split("@")
            metric_name = parts[0].strip()
            metric_threshold = parts[1].strip()
            if metric_name not in ["error", "ndcg", "map"]:
                raise exc.UserError(
                    "Metric '{}' is not supported. Parameter 'eval_metric' with customized threshold should "
                    "be one of these options: 'error', 'ndcg', 'map'.".format(metric)
                )
            try:
                float(metric_threshold)
            except ValueError:
                raise exc.UserError("Threshold value 't' in '{}@t' expects float input.".format(metric_name))
            return True

        return metric in SUPPORTED_METRIC

    @hpv.dependencies_validator(["objective"])
    def eval_metric_dep_validator(value, dependencies):
        """Check that objective-specific metrics are paired with a compatible objective.

        :raises exc.UserError: when 'auc' is used without a binary/ranking objective, or
            'aft-nloglik' without the 'survival:aft' objective.
        """
        # The objective may be missing/None when the user did not set it; fall back to ""
        # so the checks below raise UserError rather than AttributeError on None.
        objective = dependencies.get("objective") or ""
        if "auc" in value:
            if not objective.startswith(("binary:", "rank:")):
                raise exc.UserError("Metric 'auc' can only be applied for classification and ranking problems.")
        if "aft-nloglik" in value:
            if objective != "survival:aft":
                raise exc.UserError("Metric 'aft-nloglik' can only be applied for 'survival:aft' objective.")

    @hpv.dependencies_validator(["tree_method"])
    def monotone_constraints_validator(value, dependencies):
        """Reject monotone_constraints unless tree_method is 'exact' or 'hist'."""
        tree_method = dependencies.get("tree_method")
        if value is not None and tree_method not in ("exact", "hist"):
            raise exc.UserError(
                "monotone_constraints can be used only when the tree_method parameter is set to "
                "either 'exact' or 'hist'."
            )

    @hpv.dependencies_validator(["tree_method"])
    def interaction_constraints_validator(value, dependencies):
        """Reject interaction_constraints unless tree_method is 'exact', 'hist' or 'approx'."""
        tree_method = dependencies.get("tree_method")
        if value is not None and tree_method not in ("exact", "hist", "approx"):
            raise exc.UserError(
                "interaction_constraints can be used only when the tree_method parameter is set to "
                "either 'exact', 'hist' or 'approx'."
            )

    hyperparameters = hpv.Hyperparameters(
        hpv.IntegerHyperparameter(
            name="num_round",
            required=True,
            range=hpv.Interval(min_closed=1),
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=1, max_closed=4000, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.IntegerHyperparameter(name="csv_weights", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.IntegerHyperparameter(name="early_stopping_rounds", range=hpv.Interval(min_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="booster", range=["gbtree", "gblinear", "dart"], required=False),
        hpv.IntegerHyperparameter(name="verbosity", range=hpv.Interval(min_closed=0, max_closed=3), required=False),
        hpv.IntegerHyperparameter(name="nthread", range=hpv.Interval(min_closed=1), required=False),
        hpv.ContinuousHyperparameter(
            name="eta",
            range=hpv.Interval(min_closed=0, max_closed=1),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=0.5, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="gamma",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=5, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.IntegerHyperparameter(
            name="max_depth",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=10, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="min_child_weight",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=120, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="max_delta_step",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=10, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="subsample",
            range=hpv.Interval(min_open=0, max_closed=1),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="colsample_bytree",
            range=hpv.Interval(min_open=0, max_closed=1),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="colsample_bylevel",
            range=hpv.Interval(min_open=0, max_closed=1),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="colsample_bynode",
            range=hpv.Interval(min_open=0, max_closed=1),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="lambda",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.ContinuousHyperparameter(
            name="alpha",
            range=hpv.Interval(min_closed=0),
            required=False,
            tunable=True,
            tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000, scale=hpv.Interval.LINEAR_SCALE),
        ),
        hpv.CategoricalHyperparameter(name="tree_method", range=tree_method_range_validator, required=False),
        hpv.ContinuousHyperparameter(name="sketch_eps", range=hpv.Interval(min_open=0, max_open=1), required=False),
        hpv.ContinuousHyperparameter(name="scale_pos_weight", range=hpv.Interval(min_open=0), required=False),
        hpv.CommaSeparatedListHyperparameter(
            name="updater",
            # Derived from the same catalogues updater_validator uses, so every plugin the
            # validator accepts (including grow_quantile_histmaker) passes the range check.
            range=valid_tree_plugins + valid_linear_plugins,
            dependencies=updater_validator,
            required=False,
        ),
        hpv.CategoricalHyperparameter(name="dsplit", range=["row", "col"], required=False),
        hpv.IntegerHyperparameter(name="refresh_leaf", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="process_type", range=["default", "update"], required=False),
        hpv.CategoricalHyperparameter(name="grow_policy", range=["depthwise", "lossguide"], required=False),
        hpv.IntegerHyperparameter(name="max_leaves", range=hpv.Interval(min_closed=0), required=False),
        hpv.IntegerHyperparameter(name="max_bin", range=hpv.Interval(min_closed=0), required=False),
        hpv.CategoricalHyperparameter(name="predictor", range=predictor_validator, required=False),
        hpv.TupleHyperparameter(
            name="monotone_constraints", range=[-1, 0, 1], required=False, dependencies=monotone_constraints_validator
        ),
        hpv.NestedListHyperparameter(
            name="interaction_constraints",
            range=hpv.Interval(min_closed=1),
            required=False,
            dependencies=interaction_constraints_validator,
        ),
        hpv.CategoricalHyperparameter(name="sample_type", range=["uniform", "weighted"], required=False),
        hpv.CategoricalHyperparameter(name="normalize_type", range=["tree", "forest"], required=False),
        hpv.ContinuousHyperparameter(name="rate_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.IntegerHyperparameter(name="one_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.ContinuousHyperparameter(name="skip_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
        hpv.ContinuousHyperparameter(
            name="lambda_bias", range=hpv.Interval(min_closed=0, max_closed=1), required=False
        ),
        hpv.ContinuousHyperparameter(
            name="tweedie_variance_power", range=hpv.Interval(min_open=1, max_open=2), required=False
        ),
        hpv.CategoricalHyperparameter(
            name="objective",
            # "aft_loss_distribution" is a separate hyperparameter (declared below), not an
            # objective, so it is intentionally not listed here.
            range=[
                "binary:logistic",
                "binary:logitraw",
                "binary:hinge",
                "count:poisson",
                "multi:softmax",
                "multi:softprob",
                "rank:pairwise",
                "rank:ndcg",
                "rank:map",
                "reg:linear",
                "reg:squarederror",
                "reg:logistic",
                "reg:gamma",
                "reg:pseudohubererror",
                "reg:squaredlogerror",
                "reg:absoluteerror",
                "reg:tweedie",
                "survival:aft",
                "survival:cox",
            ],
            dependencies=objective_validator,
            required=False,
        ),
        hpv.IntegerHyperparameter(name="num_class", range=hpv.Interval(min_closed=2), required=False),
        hpv.ContinuousHyperparameter(name="base_score", range=hpv.Interval(min_closed=0), required=False),
        hpv.IntegerHyperparameter(name="_kfold", range=hpv.Interval(min_closed=2), required=False, tunable=False),
        hpv.IntegerHyperparameter(
            name="_num_cv_round", range=hpv.Interval(min_closed=1), required=False, tunable=False
        ),
        hpv.CategoricalHyperparameter(name="_tuning_objective_metric", range=metrics.names, required=False),
        hpv.CommaSeparatedListHyperparameter(
            name="eval_metric",
            range=eval_metric_range_validator,
            dependencies=eval_metric_dep_validator,
            required=False,
        ),
        hpv.IntegerHyperparameter(
            name="seed", range=hpv.Interval(min_open=-(2 ** 31), max_open=2 ** 31 - 1), required=False
        ),
        hpv.IntegerHyperparameter(name="num_parallel_tree", range=hpv.Interval(min_closed=1), required=False),
        hpv.CategoricalHyperparameter(name="save_model_on_termination", range=["true", "false"], required=False),
        hpv.CategoricalHyperparameter(
            name="aft_loss_distribution", range=["normal", "logistic", "extreme"], required=False
        ),
        hpv.ContinuousHyperparameter(
            name="aft_loss_distribution_scale", range=hpv.Interval(min_closed=0), required=False
        ),
        hpv.CategoricalHyperparameter(name="deterministic_histogram", range=["true", "false"], required=False),
        hpv.CategoricalHyperparameter(name="sampling_method", range=["uniform", "gradient_based"], required=False),
        hpv.IntegerHyperparameter(name="prob_buffer_row", range=hpv.Interval(min_open=1.0), required=False),
        # Not an XGB training HP, but is used to determine which distributed training framework to use by SM XGB.
        hpv.CategoricalHyperparameter(name="use_dask_gpu_training", range=["true", "false"], required=False),
    )

    # XGBoost accepts these alternative names for the same parameters.
    hyperparameters.declare_alias("eta", "learning_rate")
    hyperparameters.declare_alias("gamma", "min_split_loss")
    hyperparameters.declare_alias("lambda", "reg_lambda")
    hyperparameters.declare_alias("alpha", "reg_alpha")

    return hyperparameters