in src/stepfunctions/steps/sagemaker.py [0:0]
def __init__(self, state_id, model, model_name=None, instance_type=None, tags=None, **kwargs):
    """
    Build a Task state that calls SageMaker ``CreateModel`` for ``model``.

    Args:
        state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
        model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
        model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend using the :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. If omitted, the name derived from ``model`` is used.
        instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. Only consulted when ``model`` is a framework model.
        tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
        parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None)
            You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.

    Raises:
        ValueError: If ``model`` is not an instance of ``sagemaker.model.Model``.
    """
    # FrameworkModel must be tested before Model because it subclasses Model.
    if isinstance(model, FrameworkModel):
        model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
        if model_name:
            model_parameters['ModelName'] = model_name
    elif isinstance(model, Model):
        model_parameters = {
            'ExecutionRoleArn': model.role,
            'ModelName': model_name or model.name,
            'PrimaryContainer': {
                'Environment': model.env,
                'Image': model.image_uri,
                'ModelDataUrl': model.model_data
            }
        }
    else:
        raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))

    # model_config may emit an 'S3Operations' entry; CreateModel has no such
    # field, so it must not be forwarded in the request payload.
    model_parameters.pop('S3Operations', None)

    if tags:
        # Placeholders are resolved at execution time, so pass them through
        # untouched; concrete tag dicts are converted to Key/Value pairs.
        model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

    user_parameters = kwargs.get(Field.Parameters.value)
    if isinstance(user_parameters, dict):
        # Caller-supplied request fields override the generated ones.
        merge_dicts(model_parameters, user_parameters)
    kwargs[Field.Parameters.value] = model_parameters

    # Example resource arn: arn:aws:states:::sagemaker:createModel
    kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
                                                               SageMakerApi.CreateModel)
    super(ModelStep, self).__init__(state_id, **kwargs)