in src/sagemaker/automl/automl.py [0:0]
def attach(cls, auto_ml_job_name, sagemaker_session=None):
    """Rebuild an ``AutoML`` object from an AutoML job that already exists.

    The job description and the tags attached to the job are fetched through
    the session and used to populate the constructor arguments of a new
    ``AutoML`` instance.

    Args:
        auto_ml_job_name (str): Name of the existing AutoML job.
        sagemaker_session (sagemaker.session.Session): Session used for the
            SageMaker calls (default: None). When omitted (or falsy), a new
            ``Session()`` is created.

    Returns:
        sagemaker.automl.AutoML: An ``AutoML`` instance whose
        ``current_job_name`` and ``latest_auto_ml_job`` are both set to
        ``auto_ml_job_name`` and which caches the fetched job description.
    """
    session = sagemaker_session or Session()
    description = session.describe_auto_ml_job(auto_ml_job_name)
    job_tags = session.sagemaker_client.list_tags(ResourceArn=description["AutoMLJobArn"])["Tags"]

    # Required parts of the description: indexed directly so that a missing
    # key still raises, in the same order the constructor arguments need them.
    role_arn = description["RoleArn"]
    input_channel = description["InputDataConfig"][0]
    target_attribute = input_channel["TargetAttributeName"]
    output_config = description["OutputDataConfig"]
    output_kms_key = output_config.get("KmsKeyId")
    output_location = output_config["S3OutputPath"]

    # Optional parts: an absent section falls back to an empty dict so the
    # individual settings resolve to None (or to their explicit default).
    job_config = description.get("AutoMLJobConfig", {})
    security = job_config.get("SecurityConfig", {})
    completion = job_config.get("CompletionCriteria", {})
    s3_source = input_channel["DataSource"]["S3DataSource"]
    candidate_generation = job_config.get("CandidateGenerationConfig", {})
    data_split = job_config.get("DataSplitConfig", {})
    deploy_config = description.get("ModelDeployConfig", {})

    attached = AutoML(
        role=role_arn,
        target_attribute_name=target_attribute,
        output_kms_key=output_kms_key,
        output_path=output_location,
        base_job_name=auto_ml_job_name,
        compression_type=input_channel.get("CompressionType"),
        sagemaker_session=session,
        volume_kms_key=security.get("VolumeKmsKeyId"),
        # Do not override encrypt_inter_container_traffic from config because
        # this info is pulled from an existing automl job
        encrypt_inter_container_traffic=security.get(
            "EnableInterContainerTrafficEncryption", False
        ),
        vpc_config=security.get("VpcConfig"),
        problem_type=description.get("ProblemType"),
        max_candidates=completion.get("MaxCandidates"),
        max_runtime_per_training_job_in_seconds=completion.get(
            "MaxRuntimePerTrainingJobInSeconds"
        ),
        total_job_runtime_in_seconds=completion.get("MaxAutoMLJobRuntimeInSeconds"),
        job_objective=description.get("AutoMLJobObjective", {}),
        generate_candidate_definitions_only=description.get(
            "GenerateCandidateDefinitionsOnly", False
        ),
        tags=job_tags,
        content_type=input_channel.get("ContentType"),
        s3_data_type=s3_source.get("S3DataType"),
        feature_specification_s3_uri=candidate_generation.get("FeatureSpecificationS3Uri"),
        validation_fraction=data_split.get("ValidationFraction"),
        mode=job_config.get("Mode", "HYPERPARAMETER_TUNING"),
        auto_generate_endpoint_name=deploy_config.get("AutoGenerateEndpointName", False),
        endpoint_name=deploy_config.get("EndpointName"),
        sample_weight_attribute_name=input_channel.get("SampleWeightAttributeName", None),
    )

    # Bind the new object to the existing job and keep the description that
    # was already fetched.
    attached.current_job_name = auto_ml_job_name
    attached.latest_auto_ml_job = auto_ml_job_name  # pylint: disable=W0201
    attached._auto_ml_job_desc = description
    return attached