in src/sagemaker/estimator.py [0:0]
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
    """Convert the job description to init params that can be handled by the class constructor.

    Args:
        job_details: the returned job details from a describe_training_job
            API call.
        model_channel_name (str): Name of the channel where pre-trained
            model data will be downloaded.

    Returns:
        dictionary: The transformed init_params

    Raises:
        RuntimeError: If ``AlgorithmSpecification`` contains neither
            ``AlgorithmName`` nor ``TrainingImage``.
        KeyError: If a field read unconditionally here (for example
            ``RoleArn``, ``ResourceConfig``, ``StoppingCondition`` or
            ``OutputDataConfig.S3OutputPath``) is missing from ``job_details``.
    """
    # Sub-sections of the describe response that are consulted repeatedly below.
    resource_config = job_details["ResourceConfig"]
    algorithm_spec = job_details["AlgorithmSpecification"]
    output_config = job_details["OutputDataConfig"]
    stopping_condition = job_details["StoppingCondition"]

    init_params = {
        "role": job_details["RoleArn"],
        "instance_count": resource_config["InstanceCount"],
        "instance_type": resource_config["InstanceType"],
        "volume_size": resource_config["VolumeSizeInGB"],
        "max_run": stopping_condition["MaxRuntimeInSeconds"],
        "input_mode": algorithm_spec["TrainingInputMode"],
        "base_job_name": base_from_name(job_details["TrainingJobName"]),
        "output_path": output_config["S3OutputPath"],
        # KmsKeyId is optional in OutputDataConfig; None means "no output KMS key".
        "output_kms_key": output_config.get("KmsKeyId"),
    }

    if "EnableNetworkIsolation" in job_details:
        init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]
    if "KeepAlivePeriodInSeconds" in resource_config:
        init_params["keep_alive_period_in_seconds"] = resource_config["KeepAlivePeriodInSeconds"]
    if "TrainingPlanArn" in resource_config:
        init_params["training_plan"] = resource_config["TrainingPlanArn"]

    init_params["hyperparameters"] = job_details.get("HyperParameters", {})

    # A job is defined either by an algorithm resource or by a training image.
    if "AlgorithmName" in algorithm_spec:
        init_params["algorithm_arn"] = algorithm_spec["AlgorithmName"]
    elif "TrainingImage" in algorithm_spec:
        init_params["image_uri"] = algorithm_spec["TrainingImage"]
        if "TrainingImageConfig" in algorithm_spec:
            image_config = algorithm_spec["TrainingImageConfig"]
            init_params["training_repository_access_mode"] = image_config.get(
                "TrainingRepositoryAccessMode"
            )
            init_params["training_repository_credentials_provider_arn"] = image_config.get(
                "TrainingRepositoryAuthConfig", {}
            ).get("TrainingRepositoryCredentialsProviderArn")
    else:
        raise RuntimeError(
            "Invalid AlgorithmSpecification. Either TrainingImage or "
            "AlgorithmName is expected. None was found."
        )

    # The describe response key is "MetricDefinitions"; the previous check used a
    # misspelled key (and read yet another one), so this was never populated.
    if "MetricDefinitions" in algorithm_spec:
        init_params["metric_definitions"] = algorithm_spec["MetricDefinitions"]
    if "ContainerEntrypoint" in algorithm_spec:
        init_params["container_entry_point"] = algorithm_spec["ContainerEntrypoint"]
    if "ContainerArguments" in algorithm_spec:
        init_params["container_arguments"] = algorithm_spec["ContainerArguments"]

    if "EnableInterContainerTrafficEncryption" in job_details:
        init_params["encrypt_inter_container_traffic"] = job_details[
            "EnableInterContainerTrafficEncryption"
        ]
    if "InfraCheckConfig" in job_details:
        init_params["enable_infra_check"] = job_details["InfraCheckConfig"].get(
            "EnableInfraCheck"
        )

    subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY))
    if subnets:
        init_params["subnets"] = subnets
    if security_group_ids:
        init_params["security_group_ids"] = security_group_ids

    if "InputDataConfig" in job_details and model_channel_name:
        for channel in job_details["InputDataConfig"]:
            if channel["ChannelName"] == model_channel_name:
                init_params["model_channel_name"] = model_channel_name
                # NOTE(review): assumes the model channel is S3-backed; any other
                # DataSource type would raise KeyError here.
                init_params["model_uri"] = channel["DataSource"]["S3DataSource"]["S3Uri"]
                break

    use_spot = job_details.get("EnableManagedSpotTraining", False)
    retry_strategy = job_details.get("RetryStrategy")
    if use_spot:
        init_params["use_spot_instances"] = True
    # max_wait is restored when the job used managed spot training and/or a
    # retry strategy; it is ignored otherwise.
    if use_spot or retry_strategy:
        max_wait = stopping_condition.get("MaxWaitTimeInSeconds")
        if max_wait:
            init_params["max_wait"] = max_wait
    if retry_strategy:
        init_params["max_retry_attempts"] = retry_strategy.get("MaximumRetryAttempts")

    if "RemoteDebugConfig" in job_details:
        init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
            "EnableRemoteDebug"
        )
    return init_params