def __init__()

in src/stepfunctions/steps/sagemaker.py [0:0]


    def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, output_data_config_path=None, **kwargs):
        """
        Build a Task state that calls SageMaker ``CreateTrainingJob`` for the given estimator.

        The request payload is generated from ``estimator`` (via ``training_config``) and then adjusted with the
        optional arguments below. Any ``parameters`` dict passed through ``**kwargs`` is merged in last, so it can
        override generated fields.

        Args:
            state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
            estimator (sagemaker.estimator.EstimatorBase): The estimator for the training step. Can be a `BYO estimator, Framework estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms.html>`_ or `Amazon built-in algorithm estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`_.
            job_name (str or Placeholder): Specify a training job name, this is required for the training job to run. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
            data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator, as this can take any of the following forms:

                * (str or Placeholder) - The S3 location where training data is saved. A Placeholder is only
                    supported as a single S3 URI string (it is rewritten to an ``S3Uri.$`` JSONPath field).
                * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
                    channels for training data, you can specify a dict mapping channel names to
                    strings or :func:`~sagemaker.inputs.TrainingInput` objects.
                * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources that can
                    provide additional information about the training dataset. See
                    :func:`sagemaker.inputs.TrainingInput` for full details.
                * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                    Amazon :class:`Record` objects serialized and stored in S3.
                    For use with an estimator for an Amazon algorithm.
                * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                    :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
                    where each instance is a different channel of training data.
            hyperparameters: Parameters used for training.

                * (dict, optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
                    If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
                * (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator.
            mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
            experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
            wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
            tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
            output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
                artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
            parameters (dict, optional): Passed through ``**kwargs``. The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
                You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
        """
        self.estimator = estimator
        self.job_name = job_name

        # Choose the service-integration resource ARN; the ".sync" variant is used when waiting for completion.
        if wait_for_completion:
            # Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
            kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
                                                                       SageMakerApi.CreateTrainingJob,
                                                                       IntegrationPattern.WaitForCompletion)
        else:
            # Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
            kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
                                                                       SageMakerApi.CreateTrainingJob)

        # Convert `data` Placeholder to a JSONPath string because sagemaker.workflow.airflow.training_config does not
        # accept Placeholder in the `input` argument. We will suffix the 'S3Uri' key in `parameters` with ".$" later.
        is_data_placeholder = isinstance(data, Placeholder)
        if is_data_placeholder:
            data = data.to_jsonpath()

        # Only a literal job name is handed to training_config; a Placeholder job name is injected further below.
        if isinstance(job_name, str):
            training_parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
        else:
            training_parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)

        # The hook config may be None or False (presumably "debugger disabled") as well as a config object.
        debugger_hook_config = estimator.debugger_hook_config
        if debugger_hook_config is not None and debugger_hook_config is not False:
            training_parameters['DebugHookConfig'] = debugger_hook_config._to_request_dict()

        if estimator.rules is not None:
            training_parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]

        if isinstance(job_name, Placeholder):
            training_parameters['TrainingJobName'] = job_name

        if output_data_config_path is not None:
            training_parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path

        if is_data_placeholder:
            # Replace the 'S3Uri' key with one that supports a JSONPath value.
            # Support for uri str only: the list will only contain 1 element.
            s3_data_source = training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource']
            s3_data_source['S3Uri.$'] = s3_data_source.pop('S3Uri', None)

        if hyperparameters is not None:
            # A Placeholder replaces the estimator's hyperparameters wholesale; a dict is merged with them.
            if not isinstance(hyperparameters, Placeholder) and estimator.hyperparameters() is not None:
                hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
            training_parameters['HyperParameters'] = hyperparameters

        if experiment_config is not None:
            training_parameters['ExperimentConfig'] = experiment_config

        # 'S3Operations' is produced by training_config but is not part of the request payload built here.
        training_parameters.pop('S3Operations', None)

        if tags:
            training_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

        # User-supplied `parameters` are merged last so they take precedence over generated values.
        user_parameters = kwargs.get(Field.Parameters.value)
        if isinstance(user_parameters, dict):
            merge_dicts(training_parameters, user_parameters)

        kwargs[Field.Parameters.value] = training_parameters
        super(TrainingStep, self).__init__(state_id, **kwargs)