def _prepare_init_params_from_job_description()

in src/sagemaker/estimator.py [0:0]


    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
        """Convert the job description to init params that can be handled by the class constructor.

        Args:
            job_details: the returned job details from a describe_training_job
                API call.
            model_channel_name (str): Name of the channel where pre-trained
                model data will be downloaded.

        Returns:
            dictionary: The transformed init_params

        Raises:
            RuntimeError: If ``AlgorithmSpecification`` contains neither
                ``AlgorithmName`` nor ``TrainingImage``.
        """
        init_params = dict()

        init_params["role"] = job_details["RoleArn"]

        resource_config = job_details["ResourceConfig"]
        init_params["instance_count"] = resource_config["InstanceCount"]
        init_params["instance_type"] = resource_config["InstanceType"]
        init_params["volume_size"] = resource_config["VolumeSizeInGB"]
        init_params["max_run"] = job_details["StoppingCondition"]["MaxRuntimeInSeconds"]

        algorithm_spec = job_details["AlgorithmSpecification"]
        init_params["input_mode"] = algorithm_spec["TrainingInputMode"]
        init_params["base_job_name"] = base_from_name(job_details["TrainingJobName"])

        output_config = job_details["OutputDataConfig"]
        init_params["output_path"] = output_config["S3OutputPath"]
        # KmsKeyId is not a required field of OutputDataConfig; treat a missing
        # key as "no KMS key" (None) rather than failing with KeyError.
        init_params["output_kms_key"] = output_config.get("KmsKeyId")

        if "EnableNetworkIsolation" in job_details:
            init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]

        if "KeepAlivePeriodInSeconds" in resource_config:
            init_params["keep_alive_period_in_seconds"] = resource_config[
                "KeepAlivePeriodInSeconds"
            ]

        if "TrainingPlanArn" in resource_config:
            init_params["training_plan"] = resource_config["TrainingPlanArn"]

        init_params["hyperparameters"] = job_details.get("HyperParameters", {})

        if "AlgorithmName" in algorithm_spec:
            init_params["algorithm_arn"] = algorithm_spec["AlgorithmName"]
        elif "TrainingImage" in algorithm_spec:
            init_params["image_uri"] = algorithm_spec["TrainingImage"]
            if "TrainingImageConfig" in algorithm_spec:
                image_config = algorithm_spec["TrainingImageConfig"]
                init_params["training_repository_access_mode"] = image_config.get(
                    "TrainingRepositoryAccessMode"
                )
                init_params["training_repository_credentials_provider_arn"] = image_config.get(
                    "TrainingRepositoryAuthConfig", {}
                ).get("TrainingRepositoryCredentialsProviderArn")
        else:
            raise RuntimeError(
                "Invalid AlgorithmSpecification. Either TrainingImage or "
                "AlgorithmName is expected. None was found."
            )

        # The DescribeTrainingJob key is spelled "MetricDefinitions"; the guard
        # and the lookup must use that exact spelling for the value to carry over.
        if "MetricDefinitions" in algorithm_spec:
            init_params["metric_definitions"] = algorithm_spec["MetricDefinitions"]

        if "ContainerEntrypoint" in algorithm_spec:
            init_params["container_entry_point"] = algorithm_spec["ContainerEntrypoint"]

        if "ContainerArguments" in algorithm_spec:
            init_params["container_arguments"] = algorithm_spec["ContainerArguments"]

        if "EnableInterContainerTrafficEncryption" in job_details:
            init_params["encrypt_inter_container_traffic"] = job_details[
                "EnableInterContainerTrafficEncryption"
            ]

        if "InfraCheckConfig" in job_details:
            init_params["enable_infra_check"] = job_details["InfraCheckConfig"].get(
                "EnableInfraCheck"
            )

        subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY))
        if subnets:
            init_params["subnets"] = subnets
        if security_group_ids:
            init_params["security_group_ids"] = security_group_ids

        if "InputDataConfig" in job_details and model_channel_name:
            for channel in job_details["InputDataConfig"]:
                if channel["ChannelName"] == model_channel_name:
                    init_params["model_channel_name"] = model_channel_name
                    init_params["model_uri"] = channel["DataSource"]["S3DataSource"]["S3Uri"]
                    break

        use_spot = job_details.get("EnableManagedSpotTraining", False)
        retry_strategy = job_details.get("RetryStrategy")
        if use_spot:
            init_params["use_spot_instances"] = True
        if retry_strategy:
            init_params["max_retry_attempts"] = retry_strategy.get("MaximumRetryAttempts")
        # max_wait is carried over when either spot training or a retry
        # strategy is enabled; a falsy/absent value is left unset.
        if use_spot or retry_strategy:
            max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
            if max_wait:
                init_params["max_wait"] = max_wait

        if "RemoteDebugConfig" in job_details:
            init_params["enable_remote_debug"] = job_details["RemoteDebugConfig"].get(
                "EnableRemoteDebug"
            )
        return init_params