in src/sagemaker/tuner.py [0:0]
def _get_tuner_args(cls, tuner, inputs):
    """Gets a dict of arguments for a new Amazon SageMaker tuning job from the tuner.

    Args:
        tuner (:class:`~sagemaker.tuner.HyperparameterTuner`):
            The ``HyperparameterTuner`` instance that started the job.
        inputs: Information about the training data. Please refer to the
            ``fit()`` method of the associated estimator.

    Returns:
        dict: Arguments for the ``sagemaker.session.Session.tune`` method.
    """
    warm_start_config_req = None
    if tuner.warm_start_config:
        warm_start_config_req = tuner.warm_start_config.to_input_req()
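
    # Settings that are always present on the tuner; optional settings are
    # appended below only when they were explicitly configured, so unset
    # values never appear in the request.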
    tuning_config = {
        "strategy": tuner.strategy,
        "max_jobs": tuner.max_jobs,
        "max_parallel_jobs": tuner.max_parallel_jobs,
        "early_stopping_type": tuner.early_stopping_type,
    }
    if tuner.max_runtime_in_seconds is not None:
        tuning_config["max_runtime_in_seconds"] = tuner.max_runtime_in_seconds
    if tuner.random_seed is not None:
        tuning_config["random_seed"] = tuner.random_seed
    if tuner.strategy_config is not None:
        tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()
    if tuner.objective_metric_name is not None:
        tuning_config["objective_type"] = tuner.objective_type
        tuning_config["objective_metric_name"] = tuner.objective_metric_name
    parameter_ranges = tuner.hyperparameter_ranges()
    if parameter_ranges is not None:
        tuning_config["parameter_ranges"] = parameter_ranges
    if tuner.auto_parameters is not None:
        tuning_config["auto_parameters"] = tuner.auto_parameters
    if tuner.completion_criteria_config is not None:
        tuning_config["completion_criteria_config"] = (
            tuner.completion_criteria_config.to_input_req()
        )
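
    # Assemble the top-level arguments for the tuning-job request.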
    tuner_args = {
        "job_name": tuner._current_job_name,
        "tuning_config": tuning_config,
        "tags": tuner.tags,
        "warm_start_config": warm_start_config_req,
        "autotune": tuner.autotune,
    }
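
    # A single-estimator tuner produces one training config; a
    # multi-estimator tuner (handled below) produces a list instead.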
    if tuner.estimator is not None:
        tuner_args["training_config"] = cls._prepare_training_config(
            inputs=inputs,
            estimator=tuner.estimator,
            static_hyperparameters=tuner.static_hyperparameters,
            metric_definitions=tuner.metric_definitions,
            instance_configs=tuner.instance_configs,
        )
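
    # For multi-estimator tuning, build one training config per estimator.
    # Keys are sorted so the resulting list order is deterministic.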
    if tuner.estimator_dict is not None:
        tuner_args["training_config_list"] = [
            cls._prepare_training_config(
                inputs.get(estimator_name, None) if inputs is not None else None,
                tuner.estimator_dict[estimator_name],
                tuner.static_hyperparameters_dict[estimator_name],
                tuner.metric_definitions_dict.get(estimator_name, None),
                estimator_name,
                tuner.objective_type,
                tuner.objective_metric_name_dict[estimator_name],
                tuner.hyperparameter_ranges_dict()[estimator_name],
                (
                    tuner.instance_configs_dict.get(estimator_name, None)
                    if tuner.instance_configs_dict is not None
                    else None
                ),
                (
                    tuner.auto_parameters_dict.get(estimator_name, None)
                    if tuner.auto_parameters_dict is not None
                    else None
                ),
            )
            for estimator_name in sorted(tuner.estimator_dict.keys())
        ]

    return tuner_args
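

# Usage sketch (illustrative, not part of this file; grounded only in the
# docstring above, which says the returned dict feeds
# ``sagemaker.session.Session.tune``). A caller inside the enclosing class
# would unpack the result into that session call, roughly:
#
#     tuner_args = cls._get_tuner_args(tuner, inputs)
#     tuner.sagemaker_session.tune(**tuner_args)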