def _get_tuner_args()

in src/sagemaker/tuner.py [0:0]


    def _get_tuner_args(cls, tuner, inputs):
        """Gets a dict of arguments for a new Amazon SageMaker tuning job from the tuner

        Args:
            tuner (:class:`~sagemaker.tuner.HyperparameterTuner`):
                The ``HyperparameterTuner`` instance that started the job.
            inputs: Information about the training data. Please refer to the
            ``fit()`` method of the associated estimator.
        Returns:
            Dict: dict for `sagemaker.session.Session.tune` method
        """
        warm_start_config_req = None
        if tuner.warm_start_config:
            warm_start_config_req = tuner.warm_start_config.to_input_req()

        tuning_config = {
            "strategy": tuner.strategy,
            "max_jobs": tuner.max_jobs,
            "max_parallel_jobs": tuner.max_parallel_jobs,
            "early_stopping_type": tuner.early_stopping_type,
        }

        if tuner.max_runtime_in_seconds is not None:
            tuning_config["max_runtime_in_seconds"] = tuner.max_runtime_in_seconds

        if tuner.random_seed is not None:
            tuning_config["random_seed"] = tuner.random_seed

        if tuner.strategy_config is not None:
            tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()

        if tuner.objective_metric_name is not None:
            tuning_config["objective_type"] = tuner.objective_type
            tuning_config["objective_metric_name"] = tuner.objective_metric_name

        parameter_ranges = tuner.hyperparameter_ranges()
        if parameter_ranges is not None:
            tuning_config["parameter_ranges"] = parameter_ranges

        if tuner.auto_parameters is not None:
            tuning_config["auto_parameters"] = tuner.auto_parameters

        if tuner.completion_criteria_config is not None:
            tuning_config["completion_criteria_config"] = (
                tuner.completion_criteria_config.to_input_req()
            )

        tuner_args = {
            "job_name": tuner._current_job_name,
            "tuning_config": tuning_config,
            "tags": tuner.tags,
            "warm_start_config": warm_start_config_req,
            "autotune": tuner.autotune,
        }

        if tuner.estimator is not None:
            tuner_args["training_config"] = cls._prepare_training_config(
                inputs=inputs,
                estimator=tuner.estimator,
                static_hyperparameters=tuner.static_hyperparameters,
                metric_definitions=tuner.metric_definitions,
                instance_configs=tuner.instance_configs,
            )

        if tuner.estimator_dict is not None:
            tuner_args["training_config_list"] = [
                cls._prepare_training_config(
                    inputs.get(estimator_name, None) if inputs is not None else None,
                    tuner.estimator_dict[estimator_name],
                    tuner.static_hyperparameters_dict[estimator_name],
                    tuner.metric_definitions_dict.get(estimator_name, None),
                    estimator_name,
                    tuner.objective_type,
                    tuner.objective_metric_name_dict[estimator_name],
                    tuner.hyperparameter_ranges_dict()[estimator_name],
                    (
                        tuner.instance_configs_dict.get(estimator_name, None)
                        if tuner.instance_configs_dict is not None
                        else None
                    ),
                    (
                        tuner.auto_parameters_dict.get(estimator_name, None)
                        if tuner.auto_parameters_dict is not None
                        else None
                    ),
                )
                for estimator_name in sorted(tuner.estimator_dict.keys())
            ]

        return tuner_args