in src/sagemaker/automl/automl.py [0:0]
@classmethod
def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
"""Load job_config, input_config and output config from auto_ml and inputs.
Args:
inputs (str or list[str] or AutoMLInput or list[AutoMLInput]):
if input is string,
it should be the S3 Uri where the training data is stored
and must startwith "s3://".
if the input is a list of AutoMLInputs,
it will be converted into a request dictionary with list of input data sources.
auto_ml (AutoML): an AutoML object that user initiated.
expand_role (str): The expanded role arn that allows for Sagemaker
executionts.
validate_uri (bool): indicate whether to validate the S3 uri.
Returns (dict): a config dictionary that contains input_config, output_config,
job_config and role information.
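
Example (illustrative only; the S3 URI below is a placeholder, and the call
is shown through ``cls`` because the owning class is not part of this excerpt):
config = cls._load_config("s3://my-bucket/automl/train.csv", auto_ml)
# config has keys "input_config", "output_config", "auto_ml_job_config",
# "role" and "generate_candidate_definitions_only" (plus
# "model_deploy_config" when endpoint settings are present).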
"""
# Build the three request sections: InputDataConfig (from ``inputs``),
# OutputDataConfig and AutoMLJobConfig.
if isinstance(inputs, AutoMLInput):
input_config = inputs.to_request_dict()
elif isinstance(inputs, list) and all(
isinstance(channel, AutoMLInput) for channel in inputs
):
input_config = []
for channel in inputs:
input_config.extend(channel.to_request_dict())
else:
input_config = cls._format_inputs_to_input_config(
inputs,
validate_uri,
auto_ml.compression_type,
auto_ml.target_attribute_name,
auto_ml.content_type,
auto_ml.s3_data_type,
auto_ml.sample_weight_attribute_name,
)
output_config = _Job._prepare_output_config(auto_ml.output_path, auto_ml.output_kms_key)
role = auto_ml.sagemaker_session.expand_role(auto_ml.role) if expand_role else auto_ml.role
stop_condition = cls._prepare_auto_ml_stop_condition(
auto_ml.max_candidate,
auto_ml.max_runtime_per_training_job_in_seconds,
auto_ml.total_job_runtime_in_seconds,
)
auto_ml_job_config = {
"CompletionCriteria": stop_condition,
"SecurityConfig": {
"EnableInterContainerTrafficEncryption": auto_ml.encrypt_inter_container_traffic
},
}
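# The settings below are optional and are added to the job config only when
# the corresponding attribute is set on the AutoML object.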
if auto_ml.volume_kms_key:
auto_ml_job_config["SecurityConfig"]["VolumeKmsKeyId"] = auto_ml.volume_kms_key
if auto_ml.vpc_config:
auto_ml_job_config["SecurityConfig"]["VpcConfig"] = auto_ml.vpc_config
if auto_ml.feature_specification_s3_uri:
auto_ml_job_config["CandidateGenerationConfig"] = {
"FeatureSpecificationS3Uri": auto_ml.feature_specification_s3_uri
}
if auto_ml.validation_fraction:
auto_ml_job_config["DataSplitConfig"] = {
"ValidationFraction": auto_ml.validation_fraction
}
if auto_ml.mode:
auto_ml_job_config["Mode"] = auto_ml.mode
config = {
"input_config": input_config,
"output_config": output_config,
"auto_ml_job_config": auto_ml_job_config,
"role": role,
"generate_candidate_definitions_only": auto_ml.generate_candidate_definitions_only,
}
auto_ml_model_deploy_config = {}
if auto_ml.auto_generate_endpoint_name is not None:
auto_ml_model_deploy_config["AutoGenerateEndpointName"] = (
auto_ml.auto_generate_endpoint_name
)
if not auto_ml.auto_generate_endpoint_name and auto_ml.endpoint_name is not None:
auto_ml_model_deploy_config["EndpointName"] = auto_ml.endpoint_name
if auto_ml_model_deploy_config:
config["model_deploy_config"] = auto_ml_model_deploy_config
return config
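# Usage sketch (not part of the library source; everything except _load_config
# and the returned keys is an assumption). Inside the job class that owns this
# classmethod, starting a job would look roughly like:
#
#     inputs = "s3://my-bucket/automl/train.csv"  # placeholder URI
#     config = cls._load_config(inputs, auto_ml)
#     auto_ml.sagemaker_session.auto_ml(  # assumed Session helper for CreateAutoMLJob
#         input_config=config["input_config"],
#         output_config=config["output_config"],
#         auto_ml_job_config=config["auto_ml_job_config"],
#         role=config["role"],
#         job_name="my-automl-job",  # placeholder
#         generate_candidate_definitions_only=config["generate_candidate_definitions_only"],
#     )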