in src/stepfunctions/steps/sagemaker.py [0:0]
def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs):
    """
    Args:
        state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
        tuner (sagemaker.tuner.HyperparameterTuner): The tuner whose configuration is turned into the tuning job request.
        job_name (str or Placeholder): Name of the tuning job. Passing an :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder is recommended so that the name can be supplied dynamically on each execution.
        data: Information about the training data, in any form accepted by the ``fit()`` method of the estimator associated with the tuner:

            * (str) - The S3 location where training data is saved.
            * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
              channels for training data, a dict mapping channel names to
              strings or :func:`~sagemaker.inputs.TrainingInput` objects.
            * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
              provide additional information about the training dataset. See
              :func:`sagemaker.inputs.TrainingInput` for full details.
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
              Amazon :class:`Record` objects serialized and stored in S3.
              For use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
              where each instance is a different channel of training data.
        wait_for_completion (bool, optional): If `True`, the Task state waits for the tuning job to finish before the workflow moves on to the next step (the ``.sync`` integration). If `False`, the Task state only submits the tuning job and proceeds immediately. (default: True)
        tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
        parameters (dict, optional): Supplied through ``kwargs``. Its contents are merged into the request payload built from the other arguments for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_, so it can override values derived from them.
            Any field's value may be given dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
            If a non-dict value is passed here, it is discarded and replaced by the generated payload.

    All remaining keyword arguments are forwarded unchanged to the parent class constructor.
    """
    # Pick the Step Functions service integration: the ".sync" form blocks
    # until the tuning job completes, the plain form only starts the job.
    if wait_for_completion:
        # Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
        resource_arn = get_service_integration_arn(
            SAGEMAKER_SERVICE_NAME,
            SageMakerApi.CreateHyperParameterTuningJob,
            IntegrationPattern.WaitForCompletion,
        )
    else:
        # Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
        resource_arn = get_service_integration_arn(
            SAGEMAKER_SERVICE_NAME,
            SageMakerApi.CreateHyperParameterTuningJob,
        )
    kwargs[Field.Resource.value] = resource_arn

    # Shallow copy so the top-level key edits below do not mutate the dict
    # handed back by tuning_config.
    request = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()

    if job_name is not None:
        request['HyperParameterTuningJobName'] = job_name

    # S3Operations is stripped from the payload -- presumably because it is not
    # a CreateHyperParameterTuningJob request field (TODO confirm against the SDK).
    request.pop('S3Operations', None)

    if tags:
        # Placeholders are passed through as-is; concrete tags are converted.
        if isinstance(tags, Placeholder):
            request['Tags'] = tags
        else:
            request['Tags'] = tags_dict_to_kv_list(tags)

    # A caller-provided ``parameters`` dict is merged over the generated request;
    # any non-dict value in that slot is simply overwritten below.
    user_parameters = kwargs.get(Field.Parameters.value)
    if isinstance(user_parameters, dict):
        merge_dicts(request, user_parameters)

    kwargs[Field.Parameters.value] = request
    super(TuningStep, self).__init__(state_id, **kwargs)