in src/sagemaker_xgboost_container/algorithm_mode/hyperparameter_validation.py [0:0]
def initialize(metrics):
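    """Build the hyperparameter validation spec used in algorithm mode.

    :param metrics: initialized metrics object; ``metrics.names`` supplies the accepted
        values for the ``_tuning_objective_metric`` hyperparameter.
    :return: an ``hpv.Hyperparameters`` collection of typed, range- and dependency-validated
        hyperparameter definitions with XGBoost parameter aliases registered.
    """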
@hpv.range_validator(["auto", "exact", "approx", "hist", "gpu_hist"])
def tree_method_range_validator(CATEGORIES, value):
return value in CATEGORIES
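    # 'updater' is a comma-separated list of plugins; which plugins are valid depends on
    # the booster (gblinear expects exactly one linear updater) and on process_type
    # ('update' only allows 'refresh' and 'prune').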
@hpv.dependencies_validator(["booster", "process_type"])
def updater_validator(value, dependencies):
valid_tree_plugins = [
"grow_colmaker",
"distcol",
"grow_histmaker",
"grow_skmaker",
"sync",
"refresh",
"prune",
"grow_quantile_histmaker",
]
        # tree-growing plugins; at most one of these may be selected at a time
        valid_tree_build_plugins = [
            "grow_colmaker",
            "distcol",
            "grow_histmaker",
            "grow_skmaker",
            "grow_quantile_histmaker",
        ]
valid_linear_plugins = ["shotgun", "coord_descent"]
valid_process_update_plugins = ["refresh", "prune"]
if dependencies.get("booster") == "gblinear":
# validate only one linear updater is selected
if not (len(value) == 1 and value[0] in valid_linear_plugins):
raise exc.UserError(
"Linear updater should be one of these options: {}.".format(
", ".join("'{0}'".format(valid_updater for valid_updater in valid_linear_plugins))
)
)
elif dependencies.get("process_type") == "update":
if not all(x in valid_process_update_plugins for x in value):
raise exc.UserError("process_type 'update' can only be used with updater 'refresh' and 'prune'")
else:
if not all(x in valid_tree_plugins for x in value):
raise exc.UserError(
"Tree updater should be selected from these options: 'grow_colmaker', 'distcol', 'grow_histmaker', "
"'grow_skmaker', 'grow_quantile_histmaker', 'sync', 'refresh', 'prune', "
"'shortgun', 'coord_descent'."
)
            # validate only one tree grow plugin is selected
            num_grow_plugins = sum(1 for plugin in value if plugin in valid_tree_build_plugins)
            if num_grow_plugins > 1:
                raise exc.UserError(
                    "Only one tree grow plugin can be selected. Choose one from the "
                    "following: 'grow_colmaker', 'distcol', 'grow_histmaker', "
                    "'grow_skmaker', 'grow_quantile_histmaker'."
                )
@hpv.range_validator(["auto", "cpu_predictor", "gpu_predictor"])
def predictor_validator(CATEGORIES, value):
return value in CATEGORIES
@hpv.dependencies_validator(["num_class"])
def objective_validator(value, dependencies):
num_class = dependencies.get("num_class")
if value in ("multi:softmax", "multi:softprob") and num_class is None:
raise exc.UserError("Require input for parameter 'num_class' for multi-classification")
if value is None and num_class is not None:
            raise exc.UserError(
                "Parameter 'num_class' should only be set when a multi-classification objective is used."
            )
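    # eval_metric values may carry a threshold suffix in the form "metric@t", e.g. "error@0.7"
    # or "ndcg@5"; only 'error', 'ndcg' and 'map' support this form, and t must parse as a
    # float. User-defined metric callables (rendered as "<function ...>") are rejected.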
@hpv.range_validator(XGB_MAXIMIZE_METRICS + XGB_MINIMIZE_METRICS)
def eval_metric_range_validator(SUPPORTED_METRIC, metric):
if "<function" in metric:
raise exc.UserError("User defined evaluation metric {} is not supported yet.".format(metric))
if "@" in metric:
metric_name = metric.split("@")[0].strip()
metric_threshold = metric.split("@")[1].strip()
if metric_name not in ["error", "ndcg", "map"]:
raise exc.UserError(
"Metric '{}' is not supported. Parameter 'eval_metric' with customized threshold should "
"be one of these options: 'error', 'ndcg', 'map'.".format(metric)
)
try:
float(metric_threshold)
except ValueError:
raise exc.UserError("Threshold value 't' in '{}@t' expects float input.".format(metric_name))
return True
return metric in SUPPORTED_METRIC
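    # Some metrics are only meaningful for particular objectives: 'auc' requires a binary
    # classification or ranking objective, and 'aft-nloglik' requires 'survival:aft'.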
@hpv.dependencies_validator(["objective"])
def eval_metric_dep_validator(value, dependencies):
objective = dependencies["objective"]
if "auc" in value:
if not any(objective.startswith(metric_type) for metric_type in ["binary:", "rank:"]):
raise exc.UserError("Metric 'auc' can only be applied for classification and ranking problems.")
if "aft-nloglik" in value:
if objective not in ["survival:aft"]:
raise exc.UserError("Metric 'aft-nloglik' can only be applied for 'survival:aft' objective.")
@hpv.dependencies_validator(["tree_method"])
def monotone_constraints_validator(value, dependencies):
tree_method = dependencies.get("tree_method")
if value is not None and tree_method not in ("exact", "hist"):
raise exc.UserError(
"monotone_constraints can be used only when the tree_method parameter is set to "
"either 'exact' or 'hist'."
)
@hpv.dependencies_validator(["tree_method"])
def interaction_constraints_validator(value, dependencies):
tree_method = dependencies.get("tree_method")
if value is not None and tree_method not in ("exact", "hist", "approx"):
raise exc.UserError(
"interaction_constraints can be used only when the tree_method parameter is set to "
"either 'exact', 'hist' or 'approx'."
)
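    # Aggregate every supported hyperparameter: each entry declares its type, accepted range,
    # whether it is tunable (with a recommended tuning range), and any cross-parameter
    # dependency validator defined above.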
hyperparameters = hpv.Hyperparameters(
hpv.IntegerHyperparameter(
name="num_round",
required=True,
range=hpv.Interval(min_closed=1),
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=1, max_closed=4000, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.IntegerHyperparameter(name="csv_weights", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
hpv.IntegerHyperparameter(name="early_stopping_rounds", range=hpv.Interval(min_closed=1), required=False),
hpv.CategoricalHyperparameter(name="booster", range=["gbtree", "gblinear", "dart"], required=False),
hpv.IntegerHyperparameter(name="verbosity", range=hpv.Interval(min_closed=0, max_closed=3), required=False),
hpv.IntegerHyperparameter(name="nthread", range=hpv.Interval(min_closed=1), required=False),
hpv.ContinuousHyperparameter(
name="eta",
range=hpv.Interval(min_closed=0, max_closed=1),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=0.5, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="gamma",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=5, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.IntegerHyperparameter(
name="max_depth",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=10, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="min_child_weight",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=120, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="max_delta_step",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=10, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="subsample",
range=hpv.Interval(min_open=0, max_closed=1),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="colsample_bytree",
range=hpv.Interval(min_open=0, max_closed=1),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0.5, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="colsample_bylevel",
range=hpv.Interval(min_open=0, max_closed=1),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="colsample_bynode",
range=hpv.Interval(min_open=0, max_closed=1),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0.1, max_closed=1, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="lambda",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.ContinuousHyperparameter(
name="alpha",
range=hpv.Interval(min_closed=0),
required=False,
tunable=True,
tunable_recommended_range=hpv.Interval(min_closed=0, max_closed=1000, scale=hpv.Interval.LINEAR_SCALE),
),
hpv.CategoricalHyperparameter(name="tree_method", range=tree_method_range_validator, required=False),
hpv.ContinuousHyperparameter(name="sketch_eps", range=hpv.Interval(min_open=0, max_open=1), required=False),
hpv.ContinuousHyperparameter(name="scale_pos_weight", range=hpv.Interval(min_open=0), required=False),
hpv.CommaSeparatedListHyperparameter(
name="updater",
            range=[
                "grow_colmaker",
                "distcol",
                "grow_histmaker",
                "grow_skmaker",
                "grow_quantile_histmaker",
                "sync",
                "refresh",
                "prune",
                "shotgun",
                "coord_descent",
            ],
dependencies=updater_validator,
required=False,
),
hpv.CategoricalHyperparameter(name="dsplit", range=["row", "col"], required=False),
hpv.IntegerHyperparameter(name="refresh_leaf", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
hpv.CategoricalHyperparameter(name="process_type", range=["default", "update"], required=False),
hpv.CategoricalHyperparameter(name="grow_policy", range=["depthwise", "lossguide"], required=False),
hpv.IntegerHyperparameter(name="max_leaves", range=hpv.Interval(min_closed=0), required=False),
hpv.IntegerHyperparameter(name="max_bin", range=hpv.Interval(min_closed=0), required=False),
hpv.CategoricalHyperparameter(name="predictor", range=predictor_validator, required=False),
hpv.TupleHyperparameter(
name="monotone_constraints", range=[-1, 0, 1], required=False, dependencies=monotone_constraints_validator
),
hpv.NestedListHyperparameter(
name="interaction_constraints",
range=hpv.Interval(min_closed=1),
required=False,
dependencies=interaction_constraints_validator,
),
hpv.CategoricalHyperparameter(name="sample_type", range=["uniform", "weighted"], required=False),
hpv.CategoricalHyperparameter(name="normalize_type", range=["tree", "forest"], required=False),
hpv.ContinuousHyperparameter(name="rate_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
hpv.IntegerHyperparameter(name="one_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
hpv.ContinuousHyperparameter(name="skip_drop", range=hpv.Interval(min_closed=0, max_closed=1), required=False),
hpv.ContinuousHyperparameter(
name="lambda_bias", range=hpv.Interval(min_closed=0, max_closed=1), required=False
),
hpv.ContinuousHyperparameter(
name="tweedie_variance_power", range=hpv.Interval(min_open=1, max_open=2), required=False
),
hpv.CategoricalHyperparameter(
name="objective",
range=[
"aft_loss_distribution",
"binary:logistic",
"binary:logitraw",
"binary:hinge",
"count:poisson",
"multi:softmax",
"multi:softprob",
"rank:pairwise",
"rank:ndcg",
"rank:map",
"reg:linear",
"reg:squarederror",
"reg:logistic",
"reg:gamma",
"reg:pseudohubererror",
"reg:squaredlogerror",
"reg:absoluteerror",
"reg:tweedie",
"survival:aft",
"survival:cox",
],
dependencies=objective_validator,
required=False,
),
hpv.IntegerHyperparameter(name="num_class", range=hpv.Interval(min_closed=2), required=False),
hpv.ContinuousHyperparameter(name="base_score", range=hpv.Interval(min_closed=0), required=False),
hpv.IntegerHyperparameter(name="_kfold", range=hpv.Interval(min_closed=2), required=False, tunable=False),
hpv.IntegerHyperparameter(
name="_num_cv_round", range=hpv.Interval(min_closed=1), required=False, tunable=False
),
hpv.CategoricalHyperparameter(name="_tuning_objective_metric", range=metrics.names, required=False),
hpv.CommaSeparatedListHyperparameter(
name="eval_metric",
range=eval_metric_range_validator,
dependencies=eval_metric_dep_validator,
required=False,
),
hpv.IntegerHyperparameter(
name="seed", range=hpv.Interval(min_open=-(2 ** 31), max_open=2 ** 31 - 1), required=False
),
hpv.IntegerHyperparameter(name="num_parallel_tree", range=hpv.Interval(min_closed=1), required=False),
hpv.CategoricalHyperparameter(name="save_model_on_termination", range=["true", "false"], required=False),
hpv.CategoricalHyperparameter(
name="aft_loss_distribution", range=["normal", "logistic", "extreme"], required=False
),
hpv.ContinuousHyperparameter(
name="aft_loss_distribution_scale", range=hpv.Interval(min_closed=0), required=False
),
hpv.CategoricalHyperparameter(name="deterministic_histogram", range=["true", "false"], required=False),
hpv.CategoricalHyperparameter(name="sampling_method", range=["uniform", "gradient_based"], required=False),
hpv.IntegerHyperparameter(name="prob_buffer_row", range=hpv.Interval(min_open=1.0), required=False),
# Not an XGB training HP, but is used to determine which distributed training framework to use by SM XGB.
hpv.CategoricalHyperparameter(name="use_dask_gpu_training", range=["true", "false"], required=False),
)
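    # Register XGBoost's alternate parameter names so either spelling is accepted.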
hyperparameters.declare_alias("eta", "learning_rate")
hyperparameters.declare_alias("gamma", "min_split_loss")
hyperparameters.declare_alias("lambda", "reg_lambda")
hyperparameters.declare_alias("alpha", "reg_alpha")
return hyperparameters
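

# Usage sketch (assumptions: the sibling `metrics` module exposes `initialize()` and the
# toolkit's `Hyperparameters` object exposes `validate()`, as used elsewhere in this
# container; hyperparameter names and values below are illustrative only):
#
#     from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod
#     from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv_mod
#
#     metrics = metrics_mod.initialize()
#     hyperparameters = hpv_mod.initialize(metrics)
#     validated = hyperparameters.validate({"num_round": "100", "eta": "0.3"})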