in src/sagemaker/automl/automlv2.py [0:0]
def attach(cls, auto_ml_job_name, sagemaker_session=None):
    """Attach to an existing AutoML job.

    Describes the job named ``auto_ml_job_name`` and builds an ``AutoMLV2`` instance whose
    role, problem config, output location, security settings, deployment settings,
    validation fraction and tags are populated from that description.

    Args:
        auto_ml_job_name (str): Name of the existing AutoML job.
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None). If not
            specified, a new ``Session()`` is created.

    Returns:
        sagemaker.automl.automlv2.AutoMLV2: An ``AutoMLV2`` instance bound to the
        attached AutoML job. Its ``problem_config`` is ``None`` when the job's problem
        type is not one of ImageClassification, TextClassification,
        TimeSeriesForecasting, Tabular or TextGeneration.

    Raises:
        KeyError: If the describe response lacks a required field (e.g. ``RoleArn``,
            ``OutputDataConfig``, or the sub-config for a recognised problem type).
    """
    sagemaker_session = sagemaker_session or Session()
    auto_ml_job_desc = sagemaker_session.describe_auto_ml_job_v2(auto_ml_job_name)
    # NOTE(review): only the first page of ListTags results is read; any NextToken
    # in the response is ignored.
    automl_job_tags = sagemaker_session.sagemaker_client.list_tags(
        ResourceArn=auto_ml_job_desc["AutoMLJobArn"]
    )["Tags"]
    inputs = [
        AutoMLDataChannel.from_response_dict(channel)
        for channel in auto_ml_job_desc["AutoMLJobInputDataConfig"]
    ]

    # Problem type name -> (config class, key of its sub-config inside
    # "AutoMLProblemTypeConfig"). Unrecognised problem types leave problem_config as None.
    problem_config_parsers = {
        "ImageClassification": (
            AutoMLImageClassificationConfig,
            "ImageClassificationJobConfig",
        ),
        "TextClassification": (
            AutoMLTextClassificationConfig,
            "TextClassificationJobConfig",
        ),
        "TimeSeriesForecasting": (
            AutoMLTimeSeriesForecastingConfig,
            "TimeSeriesForecastingJobConfig",
        ),
        "Tabular": (AutoMLTabularConfig, "TabularJobConfig"),
        "TextGeneration": (AutoMLTextGenerationConfig, "TextGenerationJobConfig"),
    }
    problem_type = auto_ml_job_desc["AutoMLProblemTypeConfigName"]
    problem_config = None
    if problem_type in problem_config_parsers:
        config_cls, config_key = problem_config_parsers[problem_type]
        problem_config = config_cls.from_response_dict(
            auto_ml_job_desc["AutoMLProblemTypeConfig"][config_key]
        )

    # Optional sections of the describe response; missing ones read as empty dicts.
    output_config = auto_ml_job_desc["OutputDataConfig"]
    security_config = auto_ml_job_desc.get("SecurityConfig", {})
    deploy_config = auto_ml_job_desc.get("ModelDeployConfig", {})
    data_split_config = auto_ml_job_desc.get("DataSplitConfig", {})

    amlj = AutoMLV2(
        role=auto_ml_job_desc["RoleArn"],
        problem_config=problem_config,
        output_path=output_config["S3OutputPath"],
        output_kms_key=output_config.get("KmsKeyId"),
        base_job_name=auto_ml_job_name,
        sagemaker_session=sagemaker_session,
        volume_kms_key=security_config.get("VolumeKmsKeyId"),
        # Do not override encrypt_inter_container_traffic from config because this info
        # is pulled from an existing automl job
        encrypt_inter_container_traffic=security_config.get(
            "EnableInterContainerTrafficEncryption"
        ),
        vpc_config=security_config.get("VpcConfig"),
        job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}),
        auto_generate_endpoint_name=deploy_config.get("AutoGenerateEndpointName", False),
        endpoint_name=deploy_config.get("EndpointName"),
        validation_fraction=data_split_config.get("ValidationFraction"),
        tags=automl_job_tags,
    )
    # Bind the new instance to the existing job and keep the describe response on it.
    amlj.current_job_name = auto_ml_job_name
    amlj.latest_auto_ml_job = auto_ml_job_name  # pylint: disable=W0201
    amlj._auto_ml_job_desc = auto_ml_job_desc
    amlj.inputs = inputs
    return amlj