# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores JumpStart implementation of Model class."""

from __future__ import absolute_import

from typing import Callable, Dict, List, Optional, Any, Union
import pandas as pd
from botocore.exceptions import ClientError

from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import payloads
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
from sagemaker.base_serializers import BaseSerializer
from sagemaker.enums import EndpointType
from sagemaker.explainer.explainer_config import ExplainerConfig
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.exceptions import (
    INVALID_MODEL_ID_ERROR_MSG,
    get_proprietary_model_subscription_error,
    get_proprietary_model_subscription_msg,
)
from sagemaker.jumpstart.factory.model import (
    get_default_predictor,
    get_deploy_kwargs,
    get_init_kwargs,
    get_register_kwargs,
)
from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint
from sagemaker.jumpstart.types import (
    JumpStartSerializablePayload,
    DeploymentConfigMetadata,
)
from sagemaker.jumpstart.utils import (
    validate_model_id_and_get_type,
    verify_model_region_and_return_specs,
    get_jumpstart_configs,
    get_metrics_from_deployment_configs,
    add_instance_rate_stats_to_benchmark_metrics,
    deployment_config_response_data,
    _deployment_config_lru_cache,
    _add_model_access_configs_to_model_data_sources,
)
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.model_card import (
    ModelCard,
    ModelPackageModelCard,
)
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.model import (
    Model,
    ModelPackage,
)
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
from sagemaker.predictor import PredictorBase
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model_life_cycle import ModelLifeCycle
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements


class JumpStartModel(Model):
    """JumpStartModel class.

    This class sets defaults based on the model ID and version.
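
    .. admonition:: Example

        An illustrative usage sketch; the model ID below is an assumption and any
        valid JumpStart model ID can be substituted:

        >>> from sagemaker.jumpstart.model import JumpStartModel
        >>> model = JumpStartModel(model_id="huggingface-llm-falcon-7b-bf16")
        >>> predictor = model.deploy()  # pass accept_eula=True for gated models
        >>> predictor.predict({"inputs": "What is SageMaker JumpStart?"})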
""" def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, enable_network_isolation: Union[bool, PipelineVariable] = None, model_kms_key: Optional[str] = None, image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, source_dir: Optional[str] = None, code_location: Optional[str] = None, entry_point: Optional[str] = None, container_log_level: Optional[Union[int, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, ): """Initializes a ``JumpStartModel``. This method sets model-specific defaults for the ``Model.__init__`` method. Only model ID is required to instantiate this class, however any field can be overriden. Any field set to ``None`` does not get passed to the parent class method. Args: model_id (Optional[str]): JumpStart model ID to use. See https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known security vulnerabilities. (Default: None). tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: None). region (Optional[str]): The AWS region in which to launch the model. (Default: None). instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting endpoint. (Default: None). image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None). model_data (Optional[Union[str, PipelineVariable, dict]]): Location of SageMaker model data. (Default: None). role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access some AWS resources. It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). 
            env (Optional[dict[str, str] or dict[str, PipelineVariable]]): Environment
                variables to run with ``image_uri`` when hosted in SageMaker.
                (Default: None).
            name (Optional[str]): The model name. If None, a default model name will be
                selected on each ``deploy``. (Default: None).
            vpc_config (Optional[Union[dict[str, list[str]], dict[str, list[PipelineVariable]]]]):
                The VpcConfig set on the model (Default: None):

                * 'Subnets' (list[str]): List of subnet ids.
                * 'SecurityGroupIds' (list[str]): List of security group ids.

            sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session
                object, used for SageMaker interactions. If not specified, one is created
                using the default AWS configuration chain. (Default: None).
            enable_network_isolation (Optional[Union[bool, PipelineVariable]]): If True,
                enables network isolation in the endpoint, isolating the model
                container. No inbound or outbound network calls can be made to
                or from the model container. (Default: None).
            model_kms_key (Optional[str]): KMS key ARN used to encrypt the repacked
                model archive file if the model is repacked. (Default: None).
            image_config (Optional[Union[dict[str, str], dict[str, PipelineVariable]]]):
                Specifies whether the image of model container is pulled from ECR,
                or private registry in your VPC. By default it is set to pull model
                container image from ECR. (Default: None).
            source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a
                directory with any other training source code dependencies aside from the
                entry point file. If ``source_dir`` is an S3 URI, it must point to a file
                with name ``sourcedir.tar.gz``. Structure within this directory is
                preserved when training on Amazon SageMaker. If 'git_config' is provided,
                'source_dir' should be a relative location to a directory in the Git repo.
                If the directory points to S3, no code is uploaded and the S3 location
                is used instead. (Default: None).

                .. admonition:: Example

                    With the following GitHub repo directory structure:

                    >>> |----- README.md
                    >>> |----- src
                    >>>         |----- inference.py
                    >>>         |----- test.py

                    You can assign entry_point='inference.py', source_dir='src'.
            code_location (Optional[str]): Name of the S3 bucket where custom code is
                uploaded. If not specified, the default bucket created by
                ``sagemaker.session.Session`` is used. (Default: None).
            entry_point (Optional[str]): The absolute or relative path to the local Python
                source file that should be executed as the entry point to
                model hosting. If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
                If 'git_config' is provided, 'entry_point' should be
                a relative location to the Python source file in the Git repo.
                (Default: None).

                .. admonition:: Example

                    With the following GitHub repo directory structure:

                    >>> |----- README.md
                    >>> |----- src
                    >>>         |----- inference.py
                    >>>         |----- test.py

                    You can assign entry_point='src/inference.py'.
            container_log_level (Optional[Union[int, PipelineVariable]]): Log level to use
                within the container. Valid values are defined in the Python logging
                module. (Default: None).
            dependencies (Optional[list[str]]): A list of absolute or relative paths to
                directories with any additional libraries that should be exported
                to the container. The library folders are copied to SageMaker in
                the same folder where the entrypoint is copied.
                If 'git_config' is provided, 'dependencies' should be a list of relative
                locations to directories with any additional libraries needed in the Git
                repo. If the ``source_dir`` points to S3, code will be uploaded and the S3
                location will be used instead. This is not supported with "local code" in
                Local Mode. (Default: None).

                .. admonition:: Example

                    The following call

                    >>> Model(entry_point='inference.py',
                    ...       dependencies=['my/libs/common', 'virtual-env'])

                    results in the following structure inside the container:

                    >>> $ ls
                    >>> opt/ml/code
                    >>>     |------ inference.py
                    >>>     |------ common
                    >>>     |------ virtual-env
            git_config (Optional[dict[str, str]]): Git configurations used for cloning files,
                including ``repo``, ``branch``, ``commit``, ``2FA_enabled``, ``username``,
                ``password`` and ``token``. The ``repo`` field is required. All other fields
                are optional. ``repo`` specifies the Git repository where your training
                script is stored. If you don't provide ``branch``, the default value
                'master' is used. If you don't provide ``commit``, the latest commit in the
                specified branch is used.

                ``2FA_enabled``, ``username``, ``password`` and ``token`` are used for
                authentication. For GitHub (or other Git) accounts, set ``2FA_enabled`` to
                'True' if two-factor authentication is enabled for the account, otherwise
                set it to 'False'. If you do not provide a value for ``2FA_enabled``, a
                default value of 'False' is used. CodeCommit does not support two-factor
                authentication, so do not provide "2FA_enabled" with CodeCommit
                repositories.

                For GitHub and other Git repos, when SSH URLs are provided, it doesn't
                matter whether 2FA is enabled or disabled. You should either have no
                passphrase for the SSH key pairs or have the ssh-agent configured so that
                you will not be prompted for the SSH passphrase when you run the
                'git clone' command with SSH URLs. When HTTPS URLs are provided, if 2FA is
                disabled, then either ``token`` or ``username`` and ``password`` are used
                for authentication if provided. ``Token`` is prioritized. If 2FA is
                enabled, only ``token`` is used for authentication if provided. If required
                authentication info is not provided, the SageMaker Python SDK attempts to
                use local credentials to authenticate. If that fails, an error message is
                thrown.

                For CodeCommit repos, 2FA is not supported, so ``2FA_enabled`` should not
                be provided. There is no token in CodeCommit, so ``token`` should also not
                be provided. When ``repo`` is an SSH URL, the requirements are the same as
                GitHub repos. When ``repo`` is an HTTPS URL, ``username`` and ``password``
                are used for authentication if they are provided. If they are not provided,
                the SageMaker Python SDK attempts to use either the CodeCommit credential
                helper or local credential storage for authentication. (Default: None).

                .. admonition:: Example

                    The following config results in cloning the repo specified in 'repo',
                    then checking out the specified branch and commit.

                    >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
                    >>>               'branch': 'test-branch-git-config',
                    >>>               'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
            model_package_arn (Optional[str]): An existing SageMaker Model Package arn,
                can be just the name if your account owns the Model Package.
                ``model_data`` is not required. (Default: None).
            resources (Optional[ResourceRequirements]): The compute resource requirements
                for a model to be deployed to an endpoint. Only
                EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
                (Default: None).
            config_name (Optional[str]): The name of the JumpStart config that can be
                optionally applied to the model. (Default: None).
            additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
                of SageMaker model data. (Default: None).

        Raises:
            ValueError: If the model ID is not recognized by JumpStart.
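
        .. admonition:: Example

            A minimal sketch of overriding a default; the model ID and instance type
            below are assumptions:

            >>> model = JumpStartModel(
            ...     model_id="huggingface-llm-falcon-7b-bf16",
            ...     instance_type="ml.g5.2xlarge",
            ... )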
        """
        hub_arn = None
        if hub_name:
            hub_arn = generate_hub_arn_for_init_kwargs(
                hub_name=hub_name, region=region, session=sagemaker_session
            )

        def _validate_model_id_and_type():
            return validate_model_id_and_get_type(
                model_id=model_id,
                model_version=model_version,
                region=region or getattr(sagemaker_session, "boto_region_name", None),
                script=JumpStartScriptScope.INFERENCE,
                sagemaker_session=sagemaker_session,
                hub_arn=hub_arn,
            )

        self.model_type = _validate_model_id_and_type()
        if not self.model_type:
            JumpStartModelsAccessor.reset_cache()
            self.model_type = _validate_model_id_and_type()
            if not self.model_type and not hub_arn:
                raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))

        self._model_data_is_set = model_data is not None
        model_init_kwargs = get_init_kwargs(
            model_id=model_id,
            model_from_estimator=False,
            model_type=self.model_type,
            model_version=model_version,
            hub_arn=hub_arn,
            instance_type=instance_type,
            tolerate_vulnerable_model=tolerate_vulnerable_model,
            tolerate_deprecated_model=tolerate_deprecated_model,
            region=region,
            image_uri=image_uri,
            model_data=model_data,
            source_dir=source_dir,
            entry_point=entry_point,
            env=env,
            predictor_cls=predictor_cls,
            role=role,
            name=name,
            vpc_config=vpc_config,
            sagemaker_session=sagemaker_session,
            enable_network_isolation=enable_network_isolation,
            model_kms_key=model_kms_key,
            image_config=image_config,
            code_location=code_location,
            container_log_level=container_log_level,
            dependencies=dependencies,
            git_config=git_config,
            model_package_arn=model_package_arn,
            resources=resources,
            config_name=config_name,
            additional_model_data_sources=additional_model_data_sources,
        )

        self.orig_predictor_cls = predictor_cls

        self.model_id = model_init_kwargs.model_id
        self.model_version = model_init_kwargs.model_version
        self.hub_arn = model_init_kwargs.hub_arn
        self.instance_type = model_init_kwargs.instance_type
        self.resources = model_init_kwargs.resources
        self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
        self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
        self.region = model_init_kwargs.region
        self.sagemaker_session = model_init_kwargs.sagemaker_session
        self.role = role
        self.config_name = model_init_kwargs.config_name
        self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources
        self.model_reference_arn = model_init_kwargs.model_reference_arn

        if self.model_type == JumpStartModelType.PROPRIETARY:
            self.log_subscription_warning()

        model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict()
        super(JumpStartModel, self).__init__(**model_init_kwargs_dict)

        self.model_package_arn = model_init_kwargs.model_package_arn
        self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)

        self._metadata_configs = get_jumpstart_configs(
            region=self.region,
            model_id=self.model_id,
            model_version=self.model_version,
            sagemaker_session=self.sagemaker_session,
            model_type=self.model_type,
            hub_arn=self.hub_arn,
        )

    def log_subscription_warning(self) -> None:
        """Log message prompting the customer to subscribe to the proprietary model."""
        subscription_link = verify_model_region_and_return_specs(
            region=self.region,
            model_id=self.model_id,
            version=self.model_version,
            hub_arn=self.hub_arn,
            model_type=self.model_type,
            scope=JumpStartScriptScope.INFERENCE,
            sagemaker_session=self.sagemaker_session,
        ).model_subscription_link

        JUMPSTART_LOGGER.warning(
            get_proprietary_model_subscription_msg(self.model_id, subscription_link)
        )

    def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
        """Returns all example payloads associated with the model.

        Raises:
            NotImplementedError: If the scope is not supported.
            ValueError: If the combination of arguments specified is not supported.
            VulnerableJumpStartModelError: If any of the dependencies required by the
                script have known security vulnerabilities.
            DeprecatedJumpStartModelError: If the version of the model is deprecated.
        """
        return payloads.retrieve_all_examples(
            model_id=self.model_id,
            model_version=self.model_version,
            hub_arn=self.hub_arn,
            region=self.region,
            tolerate_deprecated_model=self.tolerate_deprecated_model,
            tolerate_vulnerable_model=self.tolerate_vulnerable_model,
            sagemaker_session=self.sagemaker_session,
            model_type=self.model_type,
        )

    def retrieve_example_payload(self) -> JumpStartSerializablePayload:
        """Returns the example payload associated with the model.

        Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)`
        function.

        Raises:
            NotImplementedError: If the scope is not supported.
            ValueError: If the combination of arguments specified is not supported.
            VulnerableJumpStartModelError: If any of the dependencies required by the
                script have known security vulnerabilities.
            DeprecatedJumpStartModelError: If the version of the model is deprecated.
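
        .. admonition:: Example

            An illustrative sketch; assumes the model has already been deployed and
            ``predictor`` is the object returned by ``deploy``:

            >>> payload = model.retrieve_example_payload()
            >>> predictor.predict(payload)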
        """
        return payloads.retrieve_example(
            model_id=self.model_id,
            model_version=self.model_version,
            hub_arn=self.hub_arn,
            model_type=self.model_type,
            region=self.region,
            tolerate_deprecated_model=self.tolerate_deprecated_model,
            tolerate_vulnerable_model=self.tolerate_vulnerable_model,
            sagemaker_session=self.sagemaker_session,
        )

    def set_deployment_config(self, config_name: str, instance_type: str) -> None:
        """Sets the deployment config to apply to the model.

        Args:
            config_name (str): The name of the deployment config to apply to the model.
                Call list_deployment_configs to see the list of config names.
            instance_type (str): The instance_type that the model will use after setting
                the config.
        """
        self.__init__(
            model_id=self.model_id,
            model_version=self.model_version,
            instance_type=instance_type,
            config_name=config_name,
            sagemaker_session=self.sagemaker_session,
            role=self.role,
        )

    @property
    def deployment_config(self) -> Optional[Dict[str, Any]]:
        """The deployment config that will be applied to this model.

        Returns:
            Optional[Dict[str, Any]]: Deployment config.
        """
        if self.config_name is None:
            return None
        for config in self.list_deployment_configs():
            if config.get("DeploymentConfigName") == self.config_name:
                return config
        return None

    @property
    def benchmark_metrics(self) -> pd.DataFrame:
        """Benchmark metrics for deployment configs.

        Returns:
            pd.DataFrame: Pandas DataFrame of benchmark metrics.
        """
        df = pd.DataFrame(self._get_deployment_configs_benchmarks_data())
        blank_index = [""] * len(df)
        df.index = blank_index
        return df

    def display_benchmark_metrics(self, **kwargs) -> None:
        """Display deployment configs benchmark metrics."""
        df = self.benchmark_metrics

        instance_type = kwargs.get("instance_type")
        if instance_type:
            df = df[df["Instance Type"].str.contains(instance_type)]

        print(df.to_markdown(index=False, floatfmt=".2f"))

    def list_deployment_configs(self) -> List[Dict[str, Any]]:
        """List deployment configs for this model.

        Returns:
            List[Dict[str, Any]]: A list of deployment configs.
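
        .. admonition:: Example

            A sketch of pairing this with ``set_deployment_config``; the instance type
            below is an assumption:

            >>> configs = model.list_deployment_configs()
            >>> model.set_deployment_config(
            ...     configs[0]["DeploymentConfigName"], "ml.g5.2xlarge"
            ... )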
""" return deployment_config_response_data( self._get_deployment_configs(self.config_name, self.instance_type) ) @classmethod def attach( cls, endpoint_name: str, inference_component_name: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, hub_name: Optional[str] = None, ) -> "JumpStartModel": """Attaches a JumpStartModel object to an existing SageMaker Endpoint. The model id, version (and inference component name) can be inferred from the tags. """ inferred_model_id = inferred_model_version = inferred_inference_component_name = None if inference_component_name is None or model_id is None or model_version is None: ( inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _, ) = get_model_info_from_endpoint( endpoint_name=endpoint_name, inference_component_name=inference_component_name, sagemaker_session=sagemaker_session, ) model_id = model_id or inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name model = JumpStartModel( model_id=model_id, model_version=model_version, sagemaker_session=sagemaker_session, hub_name=hub_name, ) model.endpoint_name = endpoint_name model.inference_component_name = inference_component_name return model def _create_sagemaker_model( self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None, **kwargs, ): """Create a SageMaker Model Entity Args: instance_type (str): Optional. The EC2 instance type that this Model will be used for, this is only used to determine if the image needs GPU support or not. (Default: None). accelerator_type (str): Optional. Type of Elastic Inference accelerator to attach to an endpoint for model loading and inference, for example, 'ml.eia1.medium'. If not specified, no Elastic Inference accelerator will be attached to the endpoint. (Default: None). tags (Optional[Tags]): Optional. The list of tags to add to the model. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags (Default: None). serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Optional. Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. (Default: None). kwargs: Keyword arguments coming from the caller. This class does not require any so they are ignored. """ tags = format_tags(tags) # if the user inputs a model artifact uri, do not use model package arn to create # inference endpoint. 
        """
        inferred_model_id = inferred_model_version = inferred_inference_component_name = None

        if inference_component_name is None or model_id is None or model_version is None:
            (
                inferred_model_id,
                inferred_model_version,
                inferred_inference_component_name,
                _,
                _,
            ) = get_model_info_from_endpoint(
                endpoint_name=endpoint_name,
                inference_component_name=inference_component_name,
                sagemaker_session=sagemaker_session,
            )

        model_id = model_id or inferred_model_id
        model_version = model_version or inferred_model_version or "*"
        inference_component_name = inference_component_name or inferred_inference_component_name

        model = JumpStartModel(
            model_id=model_id,
            model_version=model_version,
            sagemaker_session=sagemaker_session,
            hub_name=hub_name,
        )
        model.endpoint_name = endpoint_name
        model.inference_component_name = inference_component_name

        return model

    def _create_sagemaker_model(
        self,
        instance_type=None,
        accelerator_type=None,
        tags=None,
        serverless_inference_config=None,
        **kwargs,
    ):
        """Create a SageMaker Model Entity.

        Args:
            instance_type (str): Optional. The EC2 instance type that this Model will be
                used for, this is only used to determine if the image needs GPU
                support or not. (Default: None).
            accelerator_type (str): Optional. Type of Elastic Inference accelerator to
                attach to an endpoint for model loading and inference, for
                example, 'ml.eia1.medium'. If not specified, no Elastic
                Inference accelerator will be attached to the endpoint.
                (Default: None).
            tags (Optional[Tags]): Optional. The list of tags to add to the model.
                Example:
                >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
                For more information about tags, see
                https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
                (Default: None).
            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
                Optional. Specifies configuration related to serverless endpoint. Instance
                type is not provided in serverless inference, so this is used to find
                image URIs. (Default: None).
            kwargs: Keyword arguments coming from the caller. This class does not require
                any so they are ignored.
        """
        tags = format_tags(tags)

        # If the user supplied a model artifact URI, do not use the model package ARN
        # to create the inference endpoint.
        if self.model_package_arn and not self._model_data_is_set:
            # When a ModelPackageArn is provided we just create the Model
            model_package = ModelPackage(
                role=self.role,
                model_data=self.model_data,
                model_package_arn=self.model_package_arn,
                sagemaker_session=self.sagemaker_session,
                predictor_cls=self.predictor_cls,
                vpc_config=self.vpc_config,
            )
            if self.name is not None:
                model_package.name = self.name
            if self.env is not None:
                model_package.env = self.env
            model_package._create_sagemaker_model(
                instance_type=instance_type,
                accelerator_type=accelerator_type,
                tags=tags,
                serverless_inference_config=serverless_inference_config,
                **kwargs,
            )
            if self._base_name is None and model_package._base_name is not None:
                self._base_name = model_package._base_name
            if self.name is None and model_package.name is not None:
                self.name = model_package.name
        else:
            super(JumpStartModel, self)._create_sagemaker_model(
                instance_type=instance_type,
                accelerator_type=accelerator_type,
                tags=tags,
                serverless_inference_config=serverless_inference_config,
                **kwargs,
            )

    def deploy(
        self,
        initial_instance_count: Optional[int] = None,
        instance_type: Optional[str] = None,
        serializer: Optional[BaseSerializer] = None,
        deserializer: Optional[BaseDeserializer] = None,
        accelerator_type: Optional[str] = None,
        endpoint_name: Optional[str] = None,
        inference_component_name: Optional[str] = None,
        tags: Optional[Tags] = None,
        kms_key: Optional[str] = None,
        wait: Optional[bool] = True,
        data_capture_config: Optional[DataCaptureConfig] = None,
        async_inference_config: Optional[AsyncInferenceConfig] = None,
        serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
        volume_size: Optional[int] = None,
        model_data_download_timeout: Optional[int] = None,
        container_startup_health_check_timeout: Optional[int] = None,
        inference_recommendation_id: Optional[str] = None,
        explainer_config: Optional[ExplainerConfig] = None,
        accept_eula: Optional[bool] = None,
        endpoint_logging: Optional[bool] = False,
        resources: Optional[ResourceRequirements] = None,
        managed_instance_scaling: Optional[str] = None,
        endpoint_type: EndpointType = EndpointType.MODEL_BASED,
        routing_config: Optional[Dict[str, Any]] = None,
        model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
        inference_ami_version: Optional[str] = None,
    ) -> PredictorBase:
        """Creates endpoint by calling base ``Model`` class ``deploy`` method.

        Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
        ``Endpoint`` from this ``Model``.

        Any field set to ``None`` does not get passed to the parent class method.

        Args:
            initial_instance_count (Optional[int]): The initial number of instances to run
                in the ``Endpoint`` created from this ``Model``. If not using
                serverless inference or the model has not called ``right_size()``,
                then it needs to be a number larger than or equal to 1.
                (Default: None).
            instance_type (Optional[str]): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
                serverless inference or the model has not called ``right_size()``,
                then it is required to deploy a model. (Default: None).
            serializer (Optional[:class:`~sagemaker.serializers.BaseSerializer`]): A
                serializer object, used to encode data for an inference endpoint.
                If ``serializer`` is not None, then ``serializer`` will override the
                default serializer. The default serializer is set by the
                ``predictor_cls``. (Default: None).
            deserializer (Optional[:class:`~sagemaker.deserializers.BaseDeserializer`]): A
                deserializer object, used to decode data from an inference endpoint.
                If ``deserializer`` is not None, then ``deserializer`` will override the
                default deserializer. The default deserializer is set by the
                ``predictor_cls``. (Default: None).
            accelerator_type (Optional[str]): Type of Elastic Inference accelerator to
                deploy this model for model loading and inference, for example,
                'ml.eia1.medium'. If not specified, no Elastic Inference
                accelerator will be attached to the endpoint. For more
                information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
                (Default: None).
            endpoint_name (Optional[str]): The name of the endpoint to create. If not
                specified, a unique endpoint name will be created. (Default: None).
            inference_component_name (Optional[str]): The name of the inference component
                to create, if applicable. (Default: None).
            tags (Optional[Tags]): Tags to attach to this specific endpoint.
                (Default: None).
            kms_key (Optional[str]): The ARN of the KMS key that is used to encrypt the
                data on the storage volume attached to the instance hosting the
                endpoint. (Default: None).
            wait (Optional[bool]): Whether the call should wait until the deployment of
                this model completes. (Default: True).
            data_capture_config (Optional[sagemaker.model_monitor.DataCaptureConfig]):
                Specifies configuration related to Endpoint data capture for use with
                Amazon SageMaker Model Monitoring. (Default: None).
            async_inference_config (Optional[sagemaker.async_inference.AsyncInferenceConfig]):
                Specifies configuration related to async endpoint. Use this configuration
                when trying to create an async endpoint and make async inference. If an
                empty config object is passed through, a default config will be used to
                deploy the async endpoint. Deploy a real-time endpoint if it's None.
                (Default: None).
            serverless_inference_config (Optional[sagemaker.serverless.ServerlessInferenceConfig]):
                Specifies configuration related to serverless endpoint. Use this
                configuration when trying to create a serverless endpoint and make
                serverless inference. If an empty object is passed through, pre-defined
                values in the ``ServerlessInferenceConfig`` class will be used to deploy
                the serverless endpoint. Deploy an instance-based endpoint if it's None.
                (Default: None).
            volume_size (Optional[int]): The size, in GB, of the ML storage volume
                attached to the individual inference instance associated with the
                production variant. Currently only Amazon EBS gp2 storage volumes are
                supported. (Default: None).
            model_data_download_timeout (Optional[int]): The timeout value, in seconds, to
                download and extract model data from Amazon S3 to the individual inference
                instance associated with this production variant. (Default: None).
            container_startup_health_check_timeout (Optional[int]): The timeout value, in
                seconds, for your inference container to pass health check by SageMaker
                Hosting. For more information about health check, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
                (Default: None).
            inference_recommendation_id (Optional[str]): The recommendation id which
                specifies the recommendation you picked from inference recommendation job
                results and would like to deploy the model and endpoint with recommended
                parameters. (Default: None).
            explainer_config (Optional[sagemaker.explainer.ExplainerConfig]): Specifies
                online explainability configuration for use with Amazon SageMaker Clarify.
                (Default: None).
            accept_eula (Optional[bool]): For models that require a Model Access Config,
                specify True or False to indicate whether model terms of use have been
                accepted. The ``accept_eula`` value must be explicitly defined as ``True``
                in order to accept the end-user license agreement (EULA) that some models
                require. (Default: None).
            endpoint_logging (Optional[bool]): If set to True, live logging will be
                emitted as the SageMaker Endpoint starts up. (Default: False).
            resources (Optional[ResourceRequirements]): The compute resource requirements
                for a model to be deployed to an endpoint. Only
                EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
                (Default: None).
            managed_instance_scaling (Optional[str]): Managed instance scaling options;
                if configured, Amazon SageMaker will manage the instance number behind the
                endpoint. (Default: None).
            endpoint_type (EndpointType): The type of endpoint used to deploy models.
                (Default: EndpointType.MODEL_BASED).
            routing_config (Optional[Dict[str, Any]]): Settings that control how the
                endpoint routes incoming traffic to the instances that the endpoint hosts.
                (Default: None).
            model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that
                require a ModelAccessConfig, provide a
                ``{"model_id": ModelAccessConfig(accept_eula=True)}`` mapping to indicate
                whether model terms of use have been accepted. The ``accept_eula`` value
                must be explicitly defined as ``True`` in order to accept the end-user
                license agreement (EULA) that some models require. (Default: None).
            inference_ami_version (Optional[str]): Specifies an option from a collection
                of preconfigured Amazon Machine Image (AMI) images for the inference
                endpoint. (Default: None).

        Returns:
            A predictor for the deployed endpoint.

        Raises:
            MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
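
        .. admonition:: Example

            A minimal sketch; the instance type below is an assumption and
            ``accept_eula`` is only needed for models gated behind a EULA:

            >>> predictor = model.deploy(
            ...     instance_type="ml.g5.2xlarge",
            ...     accept_eula=True,
            ... )
            >>> predictor.predict({"inputs": "Hello!"})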
        """
        deploy_kwargs = get_deploy_kwargs(
            model_id=self.model_id,
            model_version=self.model_version,
            region=self.region,
            hub_arn=self.hub_arn,
            tolerate_deprecated_model=self.tolerate_deprecated_model,
            tolerate_vulnerable_model=self.tolerate_vulnerable_model,
            initial_instance_count=initial_instance_count,
            instance_type=instance_type or self.instance_type,
            serializer=serializer,
            deserializer=deserializer,
            accelerator_type=accelerator_type,
            endpoint_name=endpoint_name,
            inference_component_name=inference_component_name,
            tags=format_tags(tags),
            kms_key=kms_key,
            wait=wait,
            data_capture_config=data_capture_config,
            async_inference_config=async_inference_config,
            serverless_inference_config=serverless_inference_config,
            volume_size=volume_size,
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
            inference_recommendation_id=inference_recommendation_id,
            explainer_config=explainer_config,
            sagemaker_session=self.sagemaker_session,
            accept_eula=accept_eula,
            model_reference_arn=self.model_reference_arn,
            endpoint_logging=endpoint_logging,
            resources=resources,
            managed_instance_scaling=managed_instance_scaling,
            endpoint_type=endpoint_type,
            model_type=self.model_type,
            config_name=self.config_name,
            routing_config=routing_config,
            model_access_configs=model_access_configs,
            inference_ami_version=inference_ami_version,
        )
        if (
            self.model_type == JumpStartModelType.PROPRIETARY
            and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED
        ):
            raise ValueError(
                f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for "
                "Proprietary models."
            )

        # No resources given to deploy() but a 'resources' key present in deploy_kwargs
        # means the default JumpStart resource requirements are being used.
        if hasattr(self, "_is_sharded_model") and not resources and deploy_kwargs.resources:
            if (
                self._is_sharded_model
                and deploy_kwargs.resources.num_cpus
                and deploy_kwargs.resources.num_cpus > 0
            ):
                JUMPSTART_LOGGER.warning(
                    "NumOfCpuCoresRequired should be 0 for the best experience with "
                    "SageMaker Fast Model Loading. Overriding the requested "
                    "`num_cpus` to 0."
                )
                deploy_kwargs.resources.num_cpus = 0

        self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
            self.additional_model_data_sources,
            deploy_kwargs.model_access_configs,
            deploy_kwargs.model_id,
            deploy_kwargs.region,
        )

        try:
            predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
        except ClientError as e:
            subscription_link = verify_model_region_and_return_specs(
                region=self.region,
                model_id=self.model_id,
                version=self.model_version,
                model_type=self.model_type,
                scope=JumpStartScriptScope.INFERENCE,
                sagemaker_session=self.sagemaker_session,
                config_name=self.config_name,
                hub_arn=self.hub_arn,
            ).model_subscription_link
            get_proprietary_model_subscription_error(e, subscription_link)
            raise

        # If no predictor class was passed, add defaults to predictor
        if self.orig_predictor_cls is None and async_inference_config is None:
            return get_default_predictor(
                predictor=predictor,
                model_id=self.model_id,
                model_version=self.model_version,
                hub_arn=self.hub_arn,
                region=self.region,
                tolerate_deprecated_model=self.tolerate_deprecated_model,
                tolerate_vulnerable_model=self.tolerate_vulnerable_model,
                sagemaker_session=self.sagemaker_session,
                model_type=self.model_type,
                config_name=self.config_name,
            )

        # If a predictor class was passed, do not mutate predictor
        return predictor

    def register(
        self,
        content_types: Optional[List[Union[str, PipelineVariable]]] = None,
        response_types: Optional[List[Union[str, PipelineVariable]]] = None,
        inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
        transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
        model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
        image_uri: Optional[Union[str, PipelineVariable]] = None,
        model_metrics: Optional[ModelMetrics] = None,
        metadata_properties: Optional[MetadataProperties] = None,
        approval_status: Optional[Union[str, PipelineVariable]] = None,
        description: Optional[str] = None,
        drift_check_baselines: Optional[DriftCheckBaselines] = None,
        customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        validation_specification: Optional[Union[str, PipelineVariable]] = None,
        domain: Optional[Union[str, PipelineVariable]] = None,
        task: Optional[Union[str, PipelineVariable]] = None,
        sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
        framework: Optional[Union[str, PipelineVariable]] = None,
        framework_version: Optional[Union[str, PipelineVariable]] = None,
        nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
        data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
        skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
        source_uri: Optional[Union[str, PipelineVariable]] = None,
        model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
        accept_eula: Optional[bool] = None,
        model_life_cycle: Optional[ModelLifeCycle] = None,
    ):
        """Creates a model package for creating SageMaker models or listing on Marketplace.

        Args:
            content_types (list[str] or list[PipelineVariable]): The supported MIME types
                for the input data.
            response_types (list[str] or list[PipelineVariable]): The supported MIME types
                for the output data.
            inference_instances (list[str] or list[PipelineVariable]): A list of the
                instance types that are used to generate inferences in real-time
                (default: None).
            transform_instances (list[str] or list[PipelineVariable]): A list of the
                instance types on which a transformation job can be run or on which an
                endpoint can be deployed (default: None).
            model_package_group_name (str or PipelineVariable): Model Package Group name,
                exclusive to `model_package_name`, using `model_package_group_name` makes
                the Model Package versioned. Defaults to ``None``.
            image_uri (str or PipelineVariable): Inference image URI for the container.
                Model class' self.image will be used if it is None. Defaults to ``None``.
            model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``.
            metadata_properties (MetadataProperties): MetadataProperties object.
                Defaults to ``None``.
            approval_status (str or PipelineVariable): Model Approval Status, values can
                be "Approved", "Rejected", or "PendingManualApproval". Defaults to
                ``PendingManualApproval``.
            description (str): Model Package description. Defaults to ``None``.
            drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object
                (default: None).
            customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]):
                A dictionary of key-value paired metadata properties (default: None).
            domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION",
                "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None).
            sample_payload_url (str or PipelineVariable): The S3 path where the sample
                payload is stored (default: None).
            task (str or PipelineVariable): Task values which are supported by Inference
                Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION",
                "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION",
                "OTHER" (default: None).
            framework (str or PipelineVariable): Machine learning framework of the model
                package container image (default: None).
            framework_version (str or PipelineVariable): Framework version of the Model
                Package Container Image (default: None).
            nearest_model_name (str or PipelineVariable): Name of a pre-trained machine
                learning model benchmarked by Amazon SageMaker Inference Recommender
                (default: None).
            data_input_configuration (str or PipelineVariable): Input object for the model
                (default: None).
            skip_model_validation (str or PipelineVariable): Indicates if you want to skip
                model validation. Values can be "All" or "None" (default: None).
            source_uri (str or PipelineVariable): The URI of the source for the model
                package (default: None).
            model_card (ModelCard or ModelPackageModelCard): A document that contains
                qualitative and quantitative information about a model (default: None).
            accept_eula (bool): For models that require a Model Access Config, specify
                True or False to indicate whether model terms of use have been accepted.
                The ``accept_eula`` value must be explicitly defined as ``True`` in order
                to accept the end-user license agreement (EULA) that some models require.
                (Default: None).
            model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

        Returns:
            A `sagemaker.model.ModelPackage` instance.
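
        .. admonition:: Example

            An illustrative sketch; the model package group name and instance type
            below are assumptions:

            >>> model_package = model.register(
            ...     model_package_group_name="my-jumpstart-model-group",
            ...     accept_eula=True,
            ... )
            >>> predictor = model_package.deploy(
            ...     initial_instance_count=1,
            ...     instance_type="ml.g5.2xlarge",
            ... )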
""" if model_package_group_name is None: model_package_group_name = self.model_id if self.model_type is JumpStartModelType.PROPRIETARY: source_uri = self.model_package_arn register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, hub_arn=self.hub_arn, model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, supported_content_types=content_types, response_types=response_types, inference_instances=inference_instances, transform_instances=transform_instances, model_package_group_name=model_package_group_name, image_uri=image_uri, model_metrics=model_metrics, metadata_properties=metadata_properties, approval_status=approval_status, description=description, drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, task=task, sample_payload_url=sample_payload_url, framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, config_name=self.config_name, model_card=model_card, accept_eula=accept_eula, model_life_cycle=model_life_cycle, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) def register_deploy_wrapper(*args, **kwargs): if self.model_package_arn is not None: return self.deploy(*args, **kwargs) self.model_package_arn = model_package.model_package_arn predictor = self.deploy(*args, **kwargs) self.model_package_arn = None return predictor model_package.deploy = register_deploy_wrapper return model_package @_deployment_config_lru_cache def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: """Deployment configs benchmark metrics. Returns: Dict[str, List[str]]: Deployment config benchmark data. """ return get_metrics_from_deployment_configs( self._get_deployment_configs(None, None), ) @_deployment_config_lru_cache def _get_deployment_configs( self, selected_config_name: Optional[str], selected_instance_type: Optional[str] ) -> List[DeploymentConfigMetadata]: """Retrieve deployment configs metadata. Args: selected_config_name (Optional[str]): The name of the selected deployment config. selected_instance_type (Optional[str]): The selected instance type. 
""" deployment_configs = [] if not self._metadata_configs: return deployment_configs err = None for config_name, metadata_config in self._metadata_configs.items(): if selected_config_name == config_name: instance_type_to_use = selected_instance_type else: instance_type_to_use = metadata_config.resolved_config.get( "default_inference_instance_type" ) if metadata_config.benchmark_metrics: ( err, metadata_config.benchmark_metrics, ) = add_instance_rate_stats_to_benchmark_metrics( self.region, metadata_config.benchmark_metrics ) config_components = metadata_config.config_components.get(config_name) image_uri = ( ( config_components.hosting_instance_type_variants.get("regional_aliases", {}) .get(self.region, {}) .get("alias_ecr_uri_1") ) if config_components else self.image_uri ) init_kwargs = get_init_kwargs( config_name=config_name, model_id=self.model_id, instance_type=instance_type_to_use, sagemaker_session=self.sagemaker_session, image_uri=image_uri, region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, ) deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, instance_type=instance_type_to_use, sagemaker_session=self.sagemaker_session, region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, ) deployment_config_metadata = DeploymentConfigMetadata( config_name, metadata_config, init_kwargs, deploy_kwargs, ) deployment_configs.append(deployment_config_metadata) if err and err["Code"] == "AccessDeniedException": error_message = "Instance rate metrics will be omitted. Reason: %s" JUMPSTART_LOGGER.warning(error_message, err["Message"]) return deployment_configs def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self)