src/sagemaker/serve/mode/sagemaker_endpoint_mode.py (130 lines of code) (raw):
"""Placeholder docstring"""
from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Type
from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.model_server.triton.server import SageMakerTritonServer
from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer
from sagemaker.serve.model_server.smd.server import SageMakerSmdServer
logger = logging.getLogger(__name__)
# pylint: disable=R0901
class SageMakerEndpointMode(
    SageMakerTorchServe,
    SageMakerTritonServer,
    SageMakerDjlServing,
    SageMakerTgiServing,
    SageMakerMultiModelServer,
    SageMakerTensorflowServing,
    SageMakerSmdServer,
):
    """Holds the required methods to deploy a model to a SageMaker Endpoint.

    The class inherits the per-model-server helper classes and routes
    ``prepare`` to the artifact-upload helper that matches ``model_server``.
    TEI is the exception: it is reached through a composed
    ``SageMakerTeiServing`` instance instead of inheritance.
    """

    def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServer):
        """Store the inference spec and the model server used for dispatch.

        Args:
            inference_spec: Object whose ``model_fn`` is called by ``load``.
            model_server: Selects which upload helper ``prepare`` invokes.
        """
        super().__init__()
        # The second call resumes initialization at the class that follows
        # SageMakerTritonServer in the MRO.
        # NOTE(review): presumably required because an earlier base does not
        # chain to super().__init__() itself -- confirm against the server classes.
        # pylint: disable=bad-super-call
        super(SageMakerTritonServer, self).__init__()

        self.inference_spec = inference_spec
        self.model_server = model_server

        # TEI helpers are used via composition rather than as a base class.
        self._tei_serving = SageMakerTeiServing()

    def load(self, model_path: str):
        """Load a model through the inference spec.

        Args:
            model_path: Existing directory; its ``model`` sub-path is handed to
                ``inference_spec.model_fn``.

        Returns:
            Whatever ``inference_spec.model_fn`` returns.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            NotADirectoryError: If ``model_path`` exists but is not a directory.
        """
        path = Path(model_path)

        # Specific built-in exceptions (both are ``Exception`` subclasses, so
        # callers that catch ``Exception`` keep working).
        if not path.exists():
            raise FileNotFoundError("model_path does not exist")
        if not path.is_dir():
            raise NotADirectoryError("model_path is not a valid directory")

        model_dir = path.joinpath("model")
        return self.inference_spec.model_fn(str(model_dir))

    def prepare(
        self,
        model_path: str,
        secret_key: str,
        s3_model_data_url: str = None,
        sagemaker_session: Session = None,
        image: str = None,
        jumpstart: bool = False,
        should_upload_artifacts: bool = False,
    ):
        """Run the artifact-upload helper that matches ``self.model_server``.

        Args:
            model_path: Forwarded to the upload helper.
            secret_key: Forwarded to the helpers that accept it (TorchServe,
                Triton, TensorFlow Serving, MMS and SMD).
            s3_model_data_url: Forwarded to the upload helper.
            sagemaker_session: Session to use; a default ``Session()`` is
                created when omitted.
            image: Forwarded to the upload helper.
            jumpstart: Forwarded to the TGI helper only.
            should_upload_artifacts: Forwarded to the TGI, MMS and TEI helpers.
                The remaining helpers are always called with ``True``.

        Returns:
            The value returned by the selected upload helper, or
            ``(None, None)`` when ``model_server`` is a ``ModelServer`` member
            that has no helper here.
            NOTE(review): the helpers presumably return an
            (S3 location, environment variables) pair -- confirm.

        Raises:
            Exception: If a default SageMaker session cannot be created.
            ValueError: If ``model_server`` is not a supported ``ModelServer``.
        """
        try:
            sagemaker_session = sagemaker_session or Session()
        except Exception as e:
            raise Exception(
                "Failed to setup default SageMaker session. Please allow a default "
                "session to be created or supply `sagemaker_session` into @serve.invoke."
            ) from e

        server = self.model_server

        # Keyword arguments accepted by every upload helper.
        common = {
            "model_path": model_path,
            "sagemaker_session": sagemaker_session,
            "s3_model_data_url": s3_model_data_url,
            "image": image,
        }

        if server == ModelServer.TORCHSERVE:
            upload_artifacts = self._upload_torchserve_artifacts(
                secret_key=secret_key,
                should_upload_artifacts=True,
                **common,
            )
        elif server == ModelServer.TRITON:
            upload_artifacts = self._upload_triton_artifacts(
                secret_key=secret_key,
                should_upload_artifacts=True,
                **common,
            )
        elif server == ModelServer.DJL_SERVING:
            upload_artifacts = self._upload_djl_artifacts(
                should_upload_artifacts=True,
                **common,
            )
        elif server == ModelServer.TENSORFLOW_SERVING:
            upload_artifacts = self._upload_tensorflow_serving_artifacts(
                secret_key=secret_key,
                should_upload_artifacts=True,
                **common,
            )
        # By default, artifacts are not uploaded to S3 for TGI, MMS and TEI.
        # In case of optimization they do need to be uploaded, so for these
        # servers `should_upload_artifacts` comes from the caller of prepare.
        elif server == ModelServer.TGI:
            upload_artifacts = self._upload_tgi_artifacts(
                jumpstart=jumpstart,
                should_upload_artifacts=should_upload_artifacts,
                **common,
            )
        elif server == ModelServer.MMS:
            upload_artifacts = self._upload_server_artifacts(
                secret_key=secret_key,
                should_upload_artifacts=should_upload_artifacts,
                **common,
            )
        elif server == ModelServer.TEI:
            upload_artifacts = self._tei_serving._upload_tei_artifacts(
                should_upload_artifacts=should_upload_artifacts,
                **common,
            )
        elif server == ModelServer.SMD:
            upload_artifacts = self._upload_smd_artifacts(
                secret_key=secret_key,
                should_upload_artifacts=True,
                **common,
            )
        elif isinstance(server, ModelServer):
            # A valid enum member without a dedicated helper: nothing uploaded.
            upload_artifacts = None, None
        else:
            # Previously unreachable: the `(None, None)` default is a non-empty
            # tuple and therefore truthy, so unsupported values were returned
            # silently instead of being rejected.
            raise ValueError(f"{server!s} model server is not supported")

        return upload_artifacts