def _get_train_args()

in src/sagemaker/estimator.py [0:0]


    def _get_train_args(cls, estimator, inputs, experiment_config):
        """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

        Args:
            estimator (sagemaker.estimator.EstimatorBase): Estimator object
                created by the user.
            inputs (str or dict or TrainingInput): The ``inputs`` value passed to
                :meth:`~sagemaker.estimator.EstimatorBase.fit`. When it is a
                ``TrainingInput`` whose config contains ``InputMode``, that mode
                takes precedence over the estimator's ``input_mode``.
            experiment_config (dict[str, str]): Experiment management configuration.
                Optionally, the dict can contain four keys:
                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
                The behavior of setting these keys is as follows:
                * If `ExperimentName` is supplied but `TrialName` is not a Trial will be
                automatically created and the job's Trial Component associated with the Trial.
                * If `TrialName` is supplied and the Trial already exists the job's Trial Component
                will be associated with the Trial.
                * If both `ExperimentName` and `TrialName` are not supplied the trial component
                will be unassociated.
                * `TrialComponentDisplayName` is used for display in Studio.
                * `RunName` is used to record an experiment run.

        Returns:
            Dict: dict for `sagemaker.session.Session.train` method

        Raises:
            ValueError: If ``inputs`` or ``estimator.model_uri`` is a local (file://)
                channel while the estimator's session is not in local mode.
        """

        local_mode = estimator.sagemaker_session.local_mode
        model_uri = estimator.model_uri

        # Allow file:// input only in local mode
        if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri):
            if not local_mode:
                raise ValueError(
                    "File URIs are supported in local mode only. Please use a S3 URI instead."
                )
        config = _Job._load_config(inputs, estimator)

        # ``hyperparameters()`` may return None; bind a default so the
        # "hyperparameters" key below is always populated instead of the
        # method failing with UnboundLocalError.
        hyperparameters = None
        current_hyperparameters = estimator.hyperparameters()
        if current_hyperparameters is not None:
            hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()}

        train_args = config.copy()
        train_args["input_mode"] = estimator.input_mode
        train_args["job_name"] = estimator._current_job_name
        train_args["hyperparameters"] = hyperparameters
        train_args["tags"] = estimator.tags
        train_args["metric_definitions"] = estimator.metric_definitions
        train_args["experiment_config"] = experiment_config
        train_args["environment"] = estimator.environment

        # An InputMode set explicitly on a TrainingInput wins over the estimator's.
        if isinstance(inputs, TrainingInput) and "InputMode" in inputs.config:
            logger.debug(
                "Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
                inputs.config["InputMode"],
            )
            train_args["input_mode"] = inputs.config["InputMode"]

        # enable_network_isolation may be a pipeline variable place holder object
        # which is parsed in execution time.
        # It defaults to False and is only overridden when explicitly set, so the
        # sagemaker config doesn't override the kwarg.
        enable_network_isolation = estimator.enable_network_isolation()
        train_args["enable_network_isolation"] = (
            False if enable_network_isolation is None else enable_network_isolation
        )

        if estimator.max_retry_attempts is not None:
            train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
        else:
            train_args["retry_strategy"] = None

        if estimator.training_repository_access_mode is not None:
            training_image_config = {
                "TrainingRepositoryAccessMode": estimator.training_repository_access_mode
            }
            if estimator.training_repository_credentials_provider_arn is not None:
                training_image_config["TrainingRepositoryAuthConfig"] = {
                    "TrainingRepositoryCredentialsProviderArn": (
                        estimator.training_repository_credentials_provider_arn
                    )
                }
            train_args["training_image_config"] = training_image_config

        if estimator.enable_infra_check is not None:
            train_args["infra_check_config"] = {"EnableInfraCheck": estimator.enable_infra_check}

        if estimator.container_entry_point is not None:
            train_args["container_entry_point"] = estimator.container_entry_point

        if estimator.container_arguments is not None:
            train_args["container_arguments"] = estimator.container_arguments

        # encrypt_inter_container_traffic may be a pipeline variable place holder object
        # which is parsed in execution time
        # This does not check config because the EstimatorBase constructor already did that check
        if estimator.encrypt_inter_container_traffic:
            train_args["encrypt_inter_container_traffic"] = (
                estimator.encrypt_inter_container_traffic
            )

        # Algorithm estimators reference an algorithm ARN instead of an image URI.
        if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
            train_args["algorithm_arn"] = estimator.algorithm_arn
        else:
            train_args["image_uri"] = estimator.training_image_uri()

        if estimator.debugger_rule_configs:
            train_args["debugger_rule_configs"] = estimator.debugger_rule_configs

        if estimator.debugger_hook_config:
            # Note: this mutates the estimator's hook config in place before serializing it.
            estimator.debugger_hook_config.collection_configs = estimator.collection_configs
            train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()

        if estimator.tensorboard_output_config:
            train_args["tensorboard_output_config"] = (
                estimator.tensorboard_output_config._to_request_dict()
            )

        cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

        if estimator.enable_sagemaker_metrics is not None:
            train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics

        if estimator.profiler_rule_configs:
            train_args["profiler_rule_configs"] = estimator.profiler_rule_configs

        if estimator.profiler_config:
            train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

        # Query each optional config once and reuse the result.
        remote_debug_config = estimator.get_remote_debug_config()
        if remote_debug_config is not None:
            train_args["remote_debug_config"] = remote_debug_config

        session_chaining_config = estimator.get_session_chaining_config()
        if session_chaining_config is not None:
            train_args["session_chaining_config"] = session_chaining_config

        return train_args