in src/sagemaker/estimator.py [0:0]
def _get_train_args(cls, estimator, inputs, experiment_config):
    """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

    Args:
        estimator (sagemaker.estimator.EstimatorBase): Estimator object
            created by the user.
        inputs (str or TrainingInput or object): The ``inputs`` value passed to
            :meth:`~sagemaker.estimator.EstimatorBase.fit`. It is forwarded to
            ``_Job._load_config``. When it is a ``TrainingInput`` whose config
            contains ``InputMode``, that mode overrides the estimator's ``input_mode``.
        experiment_config (dict[str, str]): Experiment management configuration.
            Optionally, the dict can contain four keys:
            'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
            The behavior of setting these keys is as follows:
            * If `ExperimentName` is supplied but `TrialName` is not a Trial will be
            automatically created and the job's Trial Component associated with the Trial.
            * If `TrialName` is supplied and the Trial already exists the job's Trial Component
            will be associated with the Trial.
            * If both `ExperimentName` and `TrialName` are not supplied the trial component
            will be unassociated.
            * `TrialComponentDisplayName` is used for display in Studio.
            * `RunName` is used to record an experiment run.

    Returns:
        Dict: dict for `sagemaker.session.Session.train` method. Its ``hyperparameters``
            entry is ``None`` when ``estimator.hyperparameters()`` returns ``None``;
            otherwise keys and values are converted to strings.

    Raises:
        ValueError: If ``inputs`` or ``estimator.model_uri`` is a local (file://)
            channel while the session is not in local mode.
    """
    local_mode = estimator.sagemaker_session.local_mode
    model_uri = estimator.model_uri

    # Allow file:// input only in local mode.
    if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri):
        if not local_mode:
            raise ValueError(
                "File URIs are supported in local mode only. Please use a S3 URI instead."
            )

    config = _Job._load_config(inputs, estimator)

    # Bind a default first so a None result from hyperparameters() does not leave this
    # name unbound (which would raise UnboundLocalError when it is stored below).
    hyperparameters = None
    current_hyperparameters = estimator.hyperparameters()
    if current_hyperparameters is not None:
        hyperparameters = {str(k): to_string(v) for k, v in current_hyperparameters.items()}

    train_args = config.copy()
    train_args["input_mode"] = estimator.input_mode
    train_args["job_name"] = estimator._current_job_name
    train_args["hyperparameters"] = hyperparameters
    train_args["tags"] = estimator.tags
    train_args["metric_definitions"] = estimator.metric_definitions
    train_args["experiment_config"] = experiment_config
    train_args["environment"] = estimator.environment

    # An InputMode set explicitly on a TrainingInput takes precedence over the estimator's.
    if isinstance(inputs, TrainingInput) and "InputMode" in inputs.config:
        logger.debug(
            "Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
            inputs.config["InputMode"],
        )
        train_args["input_mode"] = inputs.config["InputMode"]

    # enable_network_isolation may be a pipeline variable placeholder object which is
    # parsed at execution time, so it is compared against None rather than tested for
    # truthiness. It defaults to False and is only changed when explicitly passed, so the
    # sagemaker config doesn't override the kwarg.
    enable_network_isolation = estimator.enable_network_isolation()
    train_args["enable_network_isolation"] = (
        enable_network_isolation if enable_network_isolation is not None else False
    )

    train_args["retry_strategy"] = (
        {"MaximumRetryAttempts": estimator.max_retry_attempts}
        if estimator.max_retry_attempts is not None
        else None
    )

    if estimator.training_repository_access_mode is not None:
        training_image_config = {
            "TrainingRepositoryAccessMode": estimator.training_repository_access_mode
        }
        if estimator.training_repository_credentials_provider_arn is not None:
            training_image_config["TrainingRepositoryAuthConfig"] = {
                "TrainingRepositoryCredentialsProviderArn": (
                    estimator.training_repository_credentials_provider_arn
                )
            }
        train_args["training_image_config"] = training_image_config

    if estimator.enable_infra_check is not None:
        train_args["infra_check_config"] = {"EnableInfraCheck": estimator.enable_infra_check}

    if estimator.container_entry_point is not None:
        train_args["container_entry_point"] = estimator.container_entry_point

    if estimator.container_arguments is not None:
        train_args["container_arguments"] = estimator.container_arguments

    # encrypt_inter_container_traffic may be a pipeline variable placeholder object which
    # is parsed at execution time. The sagemaker config is not checked here because the
    # EstimatorBase constructor already did that check.
    if estimator.encrypt_inter_container_traffic:
        train_args["encrypt_inter_container_traffic"] = (
            estimator.encrypt_inter_container_traffic
        )

    # AlgorithmEstimator jobs are identified by an algorithm ARN; all others by an image URI.
    if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
        train_args["algorithm_arn"] = estimator.algorithm_arn
    else:
        train_args["image_uri"] = estimator.training_image_uri()

    if estimator.debugger_rule_configs:
        train_args["debugger_rule_configs"] = estimator.debugger_rule_configs

    if estimator.debugger_hook_config:
        # Side effect: copies the estimator's collection configs onto its hook config
        # before serializing the hook config into the request.
        estimator.debugger_hook_config.collection_configs = estimator.collection_configs
        train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()

    if estimator.tensorboard_output_config:
        train_args["tensorboard_output_config"] = (
            estimator.tensorboard_output_config._to_request_dict()
        )

    # The return value is ignored; presumably this adds spot/checkpoint keys to
    # train_args in place — confirm against _add_spot_checkpoint_args.
    cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

    if estimator.enable_sagemaker_metrics is not None:
        train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics

    if estimator.profiler_rule_configs:
        train_args["profiler_rule_configs"] = estimator.profiler_rule_configs

    if estimator.profiler_config:
        train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

    # Call each getter once and reuse the result instead of invoking it twice.
    remote_debug_config = estimator.get_remote_debug_config()
    if remote_debug_config is not None:
        train_args["remote_debug_config"] = remote_debug_config

    session_chaining_config = estimator.get_session_chaining_config()
    if session_chaining_config is not None:
        train_args["session_chaining_config"] = session_chaining_config

    return train_args