def _get_train_args()

in src/sagemaker/estimator.py [0:0]


    def _get_train_args(cls, estimator, inputs, experiment_config):
        """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

        Args:
            estimator (sagemaker.estimator.EstimatorBase): Estimator object
                created by the user.
            inputs (str or TrainingInput or object): The ``inputs`` value passed to
                :meth:`~sagemaker.estimator.EstimatorBase.fit`. Any value accepted by
                ``_Job._load_config`` may be given. If it is a ``TrainingInput`` whose
                config contains ``"InputMode"``, that mode overrides
                ``estimator.input_mode``.
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain three keys:
                'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
                The behavior of setting these keys is as follows:
                * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
                automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists, the job's Trial
                Component will be associated with the Trial.
                * If neither `ExperimentName` nor `TrialName` is supplied, the trial
                component will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.

        Returns:
            dict: Keyword arguments for the `sagemaker.session.Session.train` method.
            The ``"hyperparameters"`` entry is ``None`` when
            ``estimator.hyperparameters()`` returns ``None``.

        Raises:
            ValueError: If ``inputs`` or ``estimator.model_uri`` is a local (file://)
                channel while the estimator's session is not in local mode.

        Note:
            If ``estimator.debugger_hook_config`` is set, this method overwrites its
            ``collection_configs`` attribute with ``estimator.collection_configs``.
        """

        local_mode = estimator.sagemaker_session.local_mode
        model_uri = estimator.model_uri

        # Allow file:// input only in local mode
        if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri):
            if not local_mode:
                raise ValueError(
                    "File URIs are supported in local mode only. Please use a S3 URI instead."
                )

        config = _Job._load_config(inputs, estimator)

        # Bind up front. Without this, an estimator whose hyperparameters() returns
        # None would raise UnboundLocalError when building train_args below.
        hyperparameters = None
        current_hyperparameters = estimator.hyperparameters()
        if current_hyperparameters is not None:
            # Plain values are stringified. Pipeline variables (Parameter, Expression,
            # Properties) are passed through unchanged, presumably so they are
            # resolved at pipeline execution time instead of via str(). TODO confirm.
            hyperparameters = {
                str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
                for (k, v) in current_hyperparameters.items()
            }

        train_args = config.copy()
        train_args["input_mode"] = estimator.input_mode
        train_args["job_name"] = estimator._current_job_name
        train_args["hyperparameters"] = hyperparameters
        train_args["tags"] = estimator.tags
        train_args["metric_definitions"] = estimator.metric_definitions
        train_args["experiment_config"] = experiment_config
        train_args["environment"] = estimator.environment

        # An explicit InputMode on a TrainingInput takes precedence over the
        # estimator-level input_mode set above.
        if isinstance(inputs, TrainingInput):
            if "InputMode" in inputs.config:
                logger.debug(
                    "Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
                    inputs.config["InputMode"],
                )
                train_args["input_mode"] = inputs.config["InputMode"]

        if estimator.enable_network_isolation():
            train_args["enable_network_isolation"] = True

        # Always set the key: a None value is written explicitly when no retry
        # attempts were configured.
        if estimator.max_retry_attempts is not None:
            train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
        else:
            train_args["retry_strategy"] = None

        if estimator.encrypt_inter_container_traffic:
            train_args["encrypt_inter_container_traffic"] = True

        # Marketplace-style algorithm estimators are identified by ARN; all others
        # by their training image URI.
        if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
            train_args["algorithm_arn"] = estimator.algorithm_arn
        else:
            train_args["image_uri"] = estimator.training_image_uri()

        if estimator.debugger_rule_configs:
            train_args["debugger_rule_configs"] = estimator.debugger_rule_configs

        if estimator.debugger_hook_config:
            # Side effect: the hook config object on the estimator is updated in place.
            estimator.debugger_hook_config.collection_configs = estimator.collection_configs
            train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()

        if estimator.tensorboard_output_config:
            train_args[
                "tensorboard_output_config"
            ] = estimator.tensorboard_output_config._to_request_dict()

        # The helper receives train_args and its return value is unused, so it
        # presumably adds spot-instance/checkpoint entries to train_args in place.
        cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

        if estimator.enable_sagemaker_metrics is not None:
            train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics

        if estimator.profiler_rule_configs:
            train_args["profiler_rule_configs"] = estimator.profiler_rule_configs

        if estimator.profiler_config:
            train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

        return train_args