in src/sagemaker/workflow/airflow.py [0:0]
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):  # noqa: C901
    """Export Airflow base training config from an estimator.

    Args:
        estimator (sagemaker.estimator.EstimatorBase): The estimator to export
            training config from. Can be a BYO estimator, Framework estimator or
            Amazon algorithm estimator.
        inputs: Information about the training data. Please refer to the ``fit()``
            method of the associated estimator, as this can take any of the
            following forms:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
              channels for training data, you can specify a dict mapping channel names to
              strings or :func:`~sagemaker.inputs.TrainingInput` objects.
            * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
              provide additional information about the training dataset. See
              :func:`sagemaker.inputs.TrainingInput` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
              Amazon :class:~`Record` objects serialized and stored in S3.
              For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
              where each instance is a different channel of training data.
        job_name (str): Specify a training job name if needed.
        mini_batch_size (int): Specify this argument only when estimator is a
            built-in estimator of an Amazon algorithm. For other estimators,
            batch size should be specified in the estimator.

    Returns:
        dict: Training config that can be directly used by
            SageMakerTrainingOperator in Airflow.
    """
    # Amazon built-in algorithm estimators need the records/mini-batch info to
    # prepare for training; all other estimators only take the job name.
    if isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
        estimator.prepare_workflow_for_training(
            records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
        )
    else:
        estimator.prepare_workflow_for_training(job_name=job_name)

    # Collected by the prepare_* helpers below; emitted only when non-empty.
    s3_operations = {}

    if job_name is not None:
        estimator._current_job_name = job_name
    else:
        # Derive a unique job name from the estimator's base name or its image.
        base_name = estimator.base_job_name or utils.base_name_from_image(
            estimator.training_image_uri(),
            default_base_name=EstimatorBase.JOB_CLASS_NAME,
        )
        estimator._current_job_name = utils.name_from_base(base_name)

    # Default the output path to the session's default bucket/prefix.
    if estimator.output_path is None:
        estimator.output_path = s3.s3_path_join(
            "s3://",
            estimator.sagemaker_session.default_bucket(),
            estimator.sagemaker_session.default_bucket_prefix,
            with_end_slash=True,
        )

    if isinstance(estimator, sagemaker.estimator.Framework):
        prepare_framework(estimator, s3_operations)
    elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
        prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)

    job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

    train_config = {
        "AlgorithmSpecification": {
            "TrainingImage": estimator.training_image_uri(),
            "TrainingInputMode": estimator.input_mode,
        },
        "OutputDataConfig": job_config["output_config"],
        "StoppingCondition": job_config["stop_condition"],
        "ResourceConfig": job_config["resource_config"],
        "RoleArn": job_config["role"],
    }

    if job_config["input_config"] is not None:
        train_config["InputDataConfig"] = job_config["input_config"]

    if job_config["vpc_config"] is not None:
        train_config["VpcConfig"] = job_config["vpc_config"]

    if estimator.use_spot_instances:
        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
        # which is parsed during the Pipeline execution runtime
        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances

    # Call hyperparameters() once (the original called it twice) and only
    # emit the key when the stringified mapping is non-empty.
    raw_hyperparameters = estimator.hyperparameters()
    if raw_hyperparameters is not None:
        hyperparameters = {str(k): str(v) for (k, v) in raw_hyperparameters.items()}
        if hyperparameters:
            train_config["HyperParameters"] = hyperparameters

    if s3_operations:
        train_config["S3Operations"] = s3_operations

    # FIX: use short-circuit `and`, not bitwise `&`. `&` happened to work on
    # the two bool operands but is not the intended logical operator and would
    # not short-circuit.
    if estimator.checkpoint_local_path is not None and estimator.checkpoint_s3_uri is not None:
        train_config["CheckpointConfig"] = {
            "LocalPath": estimator.checkpoint_local_path,
            "S3Uri": estimator.checkpoint_s3_uri,
        }

    return train_config