def training_base_config()

in src/sagemaker/workflow/airflow.py


def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):  # noqa: C901
    """Export Airflow base training config from an estimator

    Args:
        estimator (sagemaker.estimator.EstimatorBase): The estimator to export
            training config from. Can be a BYO estimator, Framework estimator or
            Amazon algorithm estimator.
        inputs: Information about the training data. Please refer to the
            ``fit()`` method of the associated estimator, as this can take
            any of the following forms:

            * (str) - The S3 location where training data is saved.

            * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
                  channels for training data, you can specify a dict mapping channel names to
                  strings or :class:`~sagemaker.inputs.TrainingInput` objects.

            * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
                  provide additional information about the training dataset. See
                  :class:`~sagemaker.inputs.TrainingInput` for full details.

            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                  Amazon :class:`~Record` objects serialized and stored in S3.
                  For use with an estimator for an Amazon algorithm.

            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                  :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
                  where each instance is a different channel of training data.
        job_name (str): Specify a training job name if needed.
        mini_batch_size (int): Specify this argument only when the estimator is
            a built-in estimator of an Amazon algorithm. For other estimators,
            the batch size should be set on the estimator itself.

    Returns:
        dict: Training config that can be directly used by
        SageMakerTrainingOperator in Airflow.
    """
    if isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
        estimator.prepare_workflow_for_training(
            records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
        )
    else:
        estimator.prepare_workflow_for_training(job_name=job_name)

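    # S3 upload operations (e.g., packaging local source code) that Airflow must
    # perform before the training job starts; populated by prepare_framework().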
    s3_operations = {}

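    # Resolve the training job name: honor an explicit job_name, otherwise
    # derive a unique name from the base job name or the training image.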
    if job_name is not None:
        estimator._current_job_name = job_name
    else:
        base_name = estimator.base_job_name or utils.base_name_from_image(
            estimator.training_image_uri(),
            default_base_name=EstimatorBase.JOB_CLASS_NAME,
        )
        estimator._current_job_name = utils.name_from_base(base_name)

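    # Default the output path to the session's default bucket (and prefix).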
    if estimator.output_path is None:
        estimator.output_path = s3.s3_path_join(
            "s3://",
            estimator.sagemaker_session.default_bucket(),
            estimator.sagemaker_session.default_bucket_prefix,
            with_end_slash=True,
        )

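    # Framework estimators may need local source code staged in S3; Amazon
    # algorithm estimators need their record-derived hyperparameters set.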
    if isinstance(estimator, sagemaker.estimator.Framework):
        prepare_framework(estimator, s3_operations)

    elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
        prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)

    # Build the shared job configuration (inputs, output, resources, role)
    # without expanding the role ARN or validating S3 URIs at template time.
    job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)

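    # Assemble the request in the shape expected by the CreateTrainingJob API,
    # which SageMakerTrainingOperator submits at execution time.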
    train_config = {
        "AlgorithmSpecification": {
            "TrainingImage": estimator.training_image_uri(),
            "TrainingInputMode": estimator.input_mode,
        },
        "OutputDataConfig": job_config["output_config"],
        "StoppingCondition": job_config["stop_condition"],
        "ResourceConfig": job_config["resource_config"],
        "RoleArn": job_config["role"],
    }

    if job_config["input_config"] is not None:
        train_config["InputDataConfig"] = job_config["input_config"]

    if job_config["vpc_config"] is not None:
        train_config["VpcConfig"] = job_config["vpc_config"]

    if estimator.use_spot_instances:
        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
        # which is parsed during the Pipeline execution runtime
        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances

    hyperparameters = estimator.hyperparameters()
    if hyperparameters:
        # SageMaker expects hyperparameter keys and values to be strings.
        train_config["HyperParameters"] = {
            str(k): str(v) for (k, v) in hyperparameters.items()
        }

    if s3_operations:
        train_config["S3Operations"] = s3_operations

    # Emit a checkpoint config only when both the local path and S3 URI are set.
    if estimator.checkpoint_local_path is not None and estimator.checkpoint_s3_uri is not None:
        train_config["CheckpointConfig"] = {
            "LocalPath": estimator.checkpoint_local_path,
            "S3Uri": estimator.checkpoint_s3_uri,
        }
    return train_config
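
Example usage (a minimal sketch; the image URI, role ARN, S3 paths, and job
name below are placeholders, not values taken from this module):

from sagemaker.estimator import Estimator
from sagemaker.workflow.airflow import training_base_config

estimator = Estimator(
    image_uri="<training-image-uri>",  # placeholder
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/output/",  # set explicitly to skip the default-bucket lookup
)

train_config = training_base_config(
    estimator,
    inputs="s3://my-bucket/train/",  # placeholder S3 prefix
    job_name="my-training-job",
)

# The resulting dict can be passed to Airflow's SageMakerTrainingOperator
# (from the Amazon provider package) through its `config` argument.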