src/sagemaker/jumpstart/hub/hub.py (196 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. # pylint: skip-file """This module provides the JumpStart Hub class.""" from __future__ import absolute_import from datetime import datetime import logging from typing import Optional, Dict, List, Any, Union from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session from sagemaker.jumpstart.types import ( HubContentType, ) from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, construct_hub_arn_from_name, ) from sagemaker.jumpstart.notebook_utils import ( list_jumpstart_models, ) from sagemaker.jumpstart.hub.interfaces import ( DescribeHubResponse, DescribeHubContentResponse, ) from sagemaker.jumpstart.hub.constants import ( LATEST_VERSION_WILDCARD, ) from sagemaker.jumpstart import utils class Hub: """Class for creating and managing a curated JumpStart hub""" # Setting LOGGER for backward compatibility, in case users import it... logger = LOGGER = logging.getLogger("sagemaker") _list_hubs_cache: List[Dict[str, Any]] = [] def __init__( self, hub_name: str, sagemaker_session: Session, bucket_name: Optional[str] = None, ) -> None: """Instantiates a SageMaker ``Hub``. Args: hub_name (str): The name of the Hub to create. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name self.bucket_name = bucket_name self._sagemaker_session = ( sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) ) def _get_latest_model_version(self, model_id: str) -> str: """Populates the lastest version of a model from specs no matter what is passed. Returns model ({ model_id: str, version: str }) """ model_specs = utils.verify_model_region_and_return_specs( model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region ) return model_specs.version def create( self, description: str, display_name: Optional[str] = None, search_keywords: Optional[str] = None, tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" curr_timestamp = datetime.now().timestamp() request = { "hub_name": self.hub_name, "hub_description": description, "hub_display_name": display_name, "hub_search_keywords": search_keywords, "tags": tags, } if self.bucket_name: request["s3_storage_config"] = { "S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}") } return self._sagemaker_session.create_hub(**request) def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( hub_name=self.hub_name if not hub_name else hub_name ) return hub_description def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: """List and paginate models from Hub.""" next_token: Optional[str] = None first_iteration: bool = True hub_model_summaries: List[Dict[str, Any]] = [] while first_iteration or next_token: first_iteration = False list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) hub_model_summaries.extend(list_hub_content_response.get("HubContentSummaries", [])) next_token = list_hub_content_response.get("NextToken") return hub_model_summaries def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: """Lists the models and model references in this SageMaker Hub. This function caches the models in local memory **kwargs: Passed to invocation of ``Session:list_hub_contents``. """ response = {} if clear_cache: self._list_hubs_cache = None if self._list_hubs_cache is None: hub_model_reference_summaries = self._list_and_paginate_models( **{ "hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL_REFERENCE.value, **kwargs, } ) hub_model_summaries = self._list_and_paginate_models( **{ "hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL.value, **kwargs, } ) response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries response["next_token"] = None # Temporary until pagination is implemented return response def list_sagemaker_public_hub_models( self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE), next_token: Optional[str] = None, ) -> Dict[str, Any]: """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. Args: filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), or simply a string filter which will get serialized into an Identity filter. (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). next_token (str): Optional. A token to resume pagination of list_inference_components. This is currently not implemented. """ response = {} jumpstart_public_hub_arn = construct_hub_arn_from_name( JUMPSTART_MODEL_HUB_NAME, self.region, self._sagemaker_session ) hub_content_summaries = [] models = list_jumpstart_models(filter=filter, list_versions=True) for model in models: if len(model) <= 63: info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) hub_model_arn = ( f"arn:{info.partition}:" f"sagemaker:{info.region}:" f"aws:hub-content/{info.hub_name}/" f"{HubContentType.MODEL.value}/{model[0]}" ) hub_content_summary = { "hub_content_name": model[0], "hub_content_arn": hub_model_arn, } hub_content_summaries.append(hub_content_summary) response["hub_content_summaries"] = hub_content_summaries response["next_token"] = None # Temporary until pagination is implemented for this function return response def delete(self) -> None: """Deletes this SageMaker Hub.""" return self._sagemaker_session.delete_hub(self.hub_name) def create_model_reference( self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None ): """Adds model reference to this SageMaker Hub.""" return self._sagemaker_session.create_hub_content_reference( hub_name=self.hub_name, source_hub_content_arn=model_arn, hub_content_name=model_name, min_version=min_version, ) def delete_model_reference(self, model_name: str) -> None: """Deletes model reference from this SageMaker Hub.""" return self._sagemaker_session.delete_hub_content_reference( hub_name=self.hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value, hub_content_name=model_name, ) def describe_model( self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: """Describe Model or ModelReference in a Hub.""" hub_name = hub_name or self.hub_name # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, hub_name=hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( hub_name=hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, ) except Exception as ex: logging.info( "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) # Failed to describe ModelReference, try with Model try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL.value, hub_name=hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = ( self._sagemaker_session.describe_hub_content( hub_name=hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL.value, ) ) except Exception as ex: # Failed with both, throw a custom error message raise RuntimeError( f"Cannot get details for {model_name} in Hub {hub_name}. \ {model_name} does not exist as a Model or ModelReference in {hub_name}: \n" + str(ex) ) return DescribeHubContentResponse(hub_content_description)