src/sagemaker/serve/model_server/tensorflow_serving/server.py
"""Module for Local Tensorflow Server"""
from __future__ import absolute_import
import requests
import logging
import platform
from pathlib import Path
from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
from sagemaker.session import Session
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
from sagemaker import fw_utils
from sagemaker.serve.utils.uploader import upload
from sagemaker.local.utils import get_docker_host
logger = logging.getLogger(__name__)


class LocalTensorflowServing:
    """LocalTensorflowServing class."""

    def _start_tensorflow_serving(
        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
    ):
        """Starts a local tensorflow serving container.

        Args:
            client: Docker client
            image: Image to use
            model_path: Path to the model
            secret_key: Secret key to use for authentication
            env_vars: Environment variables to set
        """
        self.container = client.containers.run(
            image,
            "serve",
            detach=True,
            auto_remove=True,
            network_mode="host",
            volumes={
                Path(model_path): {
                    "bind": "/opt/ml/model",
                    "mode": "rw",
                },
            },
            environment={
                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
                "SAGEMAKER_PROGRAM": "inference.py",
                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
                "LOCAL_PYTHON": platform.python_version(),
                **env_vars,
            },
        )

    def _invoke_tensorflow_serving(self, request: object, content_type: str, accept: str):
        """Invokes a local tensorflow serving container.

        Args:
            request: Request payload to send
            content_type: Content-Type header to send
            accept: Accept header to send
        """
        try:
            response = requests.post(
                f"http://{get_docker_host()}:8080/invocations",
                data=request,
                headers={"Content-Type": content_type, "Accept": accept},
                timeout=60,  # this is what SageMaker Hosting uses as timeout
            )
            response.raise_for_status()
            return response.content
        except Exception as e:
            raise Exception("Unable to send request to the local container server") from e

    def _tensorflow_serving_deep_ping(self, predictor: PredictorBase):
        """Checks if the local tensorflow serving container is up and running.

        Returns a (success, response) tuple, and raises LocalModelInvocationException
        when the container is reachable but rejects the sample payload.
        """
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
            return (True, response)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                raise LocalModelInvocationException(str(e))
            return (False, response)

        return (True, response)


class SageMakerTensorflowServing:
    """SageMakerTensorflowServing class."""

    def _upload_tensorflow_serving_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Uploads the model artifacts to S3.

        Args:
            model_path: Path to the model
            sagemaker_session: SageMaker session
            secret_key: Secret key to use for authentication
            s3_model_data_url: S3 model data URL
            image: Image to use
            should_upload_artifacts: Whether to upload the local model artifacts to S3

        Returns:
            A tuple of (s3_upload_path, env_vars), where s3_upload_path is None when
            no artifacts were uploaded.
        """
        s3_upload_path = None
        if _is_s3_uri(model_path):
            s3_upload_path = model_path
        elif should_upload_artifacts:
            if s3_model_data_url:
                bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
            else:
                bucket, key_prefix = None, None

            code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

            bucket, code_key_prefix = determine_bucket_and_prefix(
                bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
            )

            logger.debug(
                "Uploading the model resources to bucket=%s, key_prefix=%s.",
                bucket,
                code_key_prefix,
            )
            s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix)
            logger.debug("Model resources uploaded to: %s", s3_upload_path)

        env_vars = {
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
            "LOCAL_PYTHON": platform.python_version(),
        }
        return s3_upload_path, env_vars