in src/sagemaker/automl/automlv2.py [0:0]
def _load_config(cls, inputs, auto_ml, expand_role=True):
"""Load job_config, input_config and output config from auto_ml and inputs.
Args:
inputs (AutoMLDataChannel or list[AutoMLDataChannel]): Parameters used when called
:meth:`~sagemaker.automl.AutoML.fit`.
auto_ml (AutoMLV2): an AutoMLV2 object that user initiated.
expand_role (str): The expanded role arn that allows for Sagemaker
executionts.
validate_uri (bool): indicate whether to validate the S3 uri.
Returns (dict): a config dictionary that contains input_config, output_config,
problem_config and role information.
"""
if not inputs:
msg = (
"Cannot format input {}. Expecting an AutoMLDataChannel or "
"a list of AutoMLDataChannel or a LocalAutoMLDataChannel or a list of "
"LocalAutoMLDataChannel."
)
raise ValueError(msg.format(inputs))
if isinstance(inputs, AutoMLDataChannel):
input_config = [inputs.to_request_dict()]
elif isinstance(inputs, list) and all(
isinstance(channel, AutoMLDataChannel) for channel in inputs
):
input_config = [channel.to_request_dict() for channel in inputs]
output_config = _Job._prepare_output_config(auto_ml.output_path, auto_ml.output_kms_key)
role = auto_ml.sagemaker_session.expand_role(auto_ml.role) if expand_role else auto_ml.role
problem_config = auto_ml.problem_config.to_request_dict()
config = {
"input_config": input_config,
"output_config": output_config,
"problem_config": problem_config,
"role": role,
"job_objective": auto_ml.job_objective,
}
if (
auto_ml.volume_kms_key
or auto_ml.vpc_config
or auto_ml.encrypt_inter_container_traffic is not None
):
config["security_config"] = {}
if auto_ml.volume_kms_key:
config["security_config"]["VolumeKmsKeyId"] = auto_ml.volume_kms_key
if auto_ml.vpc_config:
config["security_config"]["VpcConfig"] = auto_ml.vpc_config
if auto_ml.encrypt_inter_container_traffic is not None:
config["security_config"][
"EnableInterContainerTrafficEncryption"
] = auto_ml.encrypt_inter_container_traffic
# Model deploy config
auto_ml_model_deploy_config = {}
if auto_ml.auto_generate_endpoint_name is not None:
auto_ml_model_deploy_config["AutoGenerateEndpointName"] = (
auto_ml.auto_generate_endpoint_name
)
if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None:
auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name
if auto_ml_model_deploy_config:
config["model_deploy_config"] = auto_ml_model_deploy_config
# Data split config
if auto_ml.validation_fraction is not None:
config["data_split_config"] = {"ValidationFraction": auto_ml.validation_fraction}
return config