in src/stepfunctions/steps/sagemaker.py [0:0]
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, output_data_config_path=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
estimator (sagemaker.estimator.EstimatorBase): The estimator for the training step. Can be a `BYO estimator, Framework estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms.html>`_ or `Amazon built-in algorithm estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`_.
job_name (str or Placeholder): Specify a training job name; this is required for the training job to run. We recommend using the :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator, as this can take any of the following forms:
* (str or Placeholder) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
provide additional information about the training dataset. See
:func:`sagemaker.inputs.TrainingInput` for full details.
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
Amazon :class:`Record` objects serialized and stored in S3.
For use with an estimator for an Amazon algorithm.
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
hyperparameters: Parameters used for training.
* (dict, optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
* (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator.
mini_batch_size (int, optional): Specify this argument only when the estimator is a built-in estimator for an Amazon algorithm. For other estimators, the batch size should be specified in the estimator itself.
experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (Default: True)
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
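Example:
    A minimal usage sketch. The estimator configuration, bucket, and execution input
    names below are illustrative assumptions, not values defined by this module:

    .. code-block:: python

        from sagemaker.estimator import Estimator
        from stepfunctions.inputs import ExecutionInput
        from stepfunctions.steps import TrainingStep

        # Pass the training job name dynamically at execution time.
        execution_input = ExecutionInput(schema={'JobName': str})

        estimator = Estimator(
            image_uri='<training-image-uri>',        # assumed placeholder value
            role='<sagemaker-execution-role-arn>',   # assumed placeholder value
            instance_count=1,
            instance_type='ml.m5.xlarge',
            output_path='s3://<bucket>/output',
        )

        training_step = TrainingStep(
            'Train Model',
            estimator=estimator,
            job_name=execution_input['JobName'],
            data='s3://<bucket>/train',
            wait_for_completion=True,
        )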
"""
self.estimator = estimator
self.job_name = job_name
if wait_for_completion:
"""
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
"""
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTrainingJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
"""
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTrainingJob)
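# The '.sync' ARN above selects the Step Functions 'run a job' (wait-for-completion) service
# integration pattern, so the Task state pauses until the training job reaches a terminal state;
# the plain ARN returns as soon as the job is submitted.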
# Convert `data` Placeholder to a JSONPath string because sagemaker.workflow.airflow.training_config does not
# accept Placeholder in the `input` argument. We will suffix the 'S3Uri' key in `parameters` with ".$" later.
is_data_placeholder = isinstance(data, Placeholder)
if is_data_placeholder:
data = data.to_jsonpath()
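# e.g. an ExecutionInput entry such as execution_input['train_data'] is expected to render as a
# JSONPath string like "$$.Execution.Input['train_data']" (illustrative).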
if isinstance(job_name, str):
training_parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
else:
training_parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
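# A Placeholder job_name cannot be passed to training_config(); it is injected into the request
# as 'TrainingJobName' below.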
if estimator.debugger_hook_config is not None and estimator.debugger_hook_config is not False:
training_parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
if estimator.rules is not None:
training_parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
if isinstance(job_name, Placeholder):
training_parameters['TrainingJobName'] = job_name
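# output_data_config_path overrides the estimator's output_path in the generated OutputDataConfig.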
if output_data_config_path is not None:
training_parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
if data is not None and is_data_placeholder:
# Replace the 'S3Uri' key with 'S3Uri.$' so the placeholder value is treated as a JSONPath.
# Only a single URI string is supported for placeholder data, so InputDataConfig contains exactly one channel.
data_uri = training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
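# Concrete (non-Placeholder) hyperparameters are merged with estimator.hyperparameters(); on duplicate
# keys the value passed to this step wins. A Placeholder replaces the estimator's hyperparameters entirely.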
if hyperparameters is not None:
if not isinstance(hyperparameters, Placeholder):
if estimator.hyperparameters() is not None:
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
training_parameters['HyperParameters'] = hyperparameters
if experiment_config is not None:
training_parameters['ExperimentConfig'] = experiment_config
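# training_config() may emit an Airflow-specific 'S3Operations' section that is not part of the
# CreateTrainingJob request, so drop it.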
if 'S3Operations' in training_parameters:
del training_parameters['S3Operations']
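# A tags Placeholder passes through untouched; concrete tags are assumed to be normalized by
# tags_dict_to_kv_list into the [{'Key': ..., 'Value': ...}] list shape that CreateTrainingJob expects.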
if tags:
training_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Merge the user-supplied `parameters` into the generated training parameters; values from `parameters` take precedence.
merge_dicts(training_parameters, kwargs[Field.Parameters.value])
kwargs[Field.Parameters.value] = training_parameters
super(TrainingStep, self).__init__(state_id, **kwargs)