in src/sagemaker/estimator.py [0:0]
def _get_train_args(cls, estimator, inputs, experiment_config):
    """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

    Args:
        estimator (sagemaker.estimator.EstimatorBase): Estimator object
            created by the user.
        inputs (str): Parameters used when calling
            :meth:`~sagemaker.estimator.EstimatorBase.fit`.
        experiment_config (dict[str, str]): Experiment management configuration.
            Optionally, the dict can contain three keys:
            'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
            The behavior of setting these keys is as follows:

            * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
              automatically created and the job's Trial Component associated with the Trial.
            * If `TrialName` is supplied and the Trial already exists, the job's Trial
              Component will be associated with the Trial.
            * If neither `ExperimentName` nor `TrialName` is supplied, the Trial Component
              will remain unassociated.
            * `TrialComponentDisplayName` is used for display in Studio.
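
            For example, a caller might pass (illustrative values)::

                {
                    "ExperimentName": "my-experiment",
                    "TrialName": "my-experiment-trial",
                    "TrialComponentDisplayName": "Training",
                }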

    Returns:
        Dict: dict for `sagemaker.session.Session.train` method
    """
    local_mode = estimator.sagemaker_session.local_mode
    model_uri = estimator.model_uri

    # Allow file:// input only in local mode
    if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri):
        if not local_mode:
            raise ValueError(
                "File URIs are supported in local mode only. Please use an S3 URI instead."
            )
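
    # _Job._load_config assembles the request pieces derived from the inputs and the
    # estimator itself -- typically the input channel config, IAM role, output and
    # resource configs, stopping condition, and VPC settings (the exact keys depend on
    # the SDK version).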
    config = _Job._load_config(inputs, estimator)

    # estimator.hyperparameters() normally returns a dict, but default to None so that
    # `train_args["hyperparameters"]` below is always defined.
    hyperparameters = None
    current_hyperparameters = estimator.hyperparameters()
    if current_hyperparameters is not None:
        hyperparameters = {
            str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
            for (k, v) in current_hyperparameters.items()
        }
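
    # At this point `hyperparameters` maps string keys to string values (e.g. an
    # illustrative {"epochs": 10} becomes {"epochs": "10"}); SageMaker Pipelines
    # placeholders (Parameter, Expression, Properties) are kept as-is so they can be
    # resolved at pipeline execution time.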
    train_args = config.copy()
    train_args["input_mode"] = estimator.input_mode
    train_args["job_name"] = estimator._current_job_name
    train_args["hyperparameters"] = hyperparameters
    train_args["tags"] = estimator.tags
    train_args["metric_definitions"] = estimator.metric_definitions
    train_args["experiment_config"] = experiment_config
    train_args["environment"] = estimator.environment
    if isinstance(inputs, TrainingInput):
        if "InputMode" in inputs.config:
            logger.debug(
                "Selecting TrainingInput's input_mode (%s) for TrainingInputMode.",
                inputs.config["InputMode"],
            )
            train_args["input_mode"] = inputs.config["InputMode"]
    if estimator.enable_network_isolation():
        train_args["enable_network_isolation"] = True

    if estimator.max_retry_attempts is not None:
        train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
    else:
        train_args["retry_strategy"] = None

    if estimator.encrypt_inter_container_traffic:
        train_args["encrypt_inter_container_traffic"] = True
    if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
        train_args["algorithm_arn"] = estimator.algorithm_arn
    else:
        train_args["image_uri"] = estimator.training_image_uri()

    if estimator.debugger_rule_configs:
        train_args["debugger_rule_configs"] = estimator.debugger_rule_configs

    if estimator.debugger_hook_config:
        estimator.debugger_hook_config.collection_configs = estimator.collection_configs
        train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()

    if estimator.tensorboard_output_config:
        train_args["tensorboard_output_config"] = (
            estimator.tensorboard_output_config._to_request_dict()
        )
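
    # Adds managed spot training and checkpointing settings (e.g. use_spot_instances and
    # the checkpoint S3 URI / local path) when they are configured on the estimator.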
    cls._add_spot_checkpoint_args(local_mode, estimator, train_args)

    if estimator.enable_sagemaker_metrics is not None:
        train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics

    if estimator.profiler_rule_configs:
        train_args["profiler_rule_configs"] = estimator.profiler_rule_configs

    if estimator.profiler_config:
        train_args["profiler_config"] = estimator.profiler_config._to_request_dict()

    return train_args
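
# A minimal sketch of how the returned dict is typically consumed (assuming the usual
# call path in this SDK; the surrounding names are illustrative, not part of this function):
#
#     train_args = cls._get_train_args(estimator, inputs, experiment_config)
#     estimator.sagemaker_session.train(**train_args)
#
# i.e. the dict's keys line up with the keyword arguments of
# sagemaker.session.Session.train, which issues the CreateTrainingJob API call.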