# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
import os
import re
import signal
import subprocess
from contextlib import contextmanager

import boto3

import tfs_utils
logging.basicConfig(
format="%(process)d %(asctime)s %(levelname)-8s %(message)s", force=True, level=logging.INFO
)
log = logging.getLogger(__name__)
JS_PING = "js_content tensorflowServing.ping"
JS_INVOCATIONS = "js_content tensorflowServing.invocations"
GUNICORN_PING = "proxy_pass http://gunicorn_upstream/ping"
GUNICORN_INVOCATIONS = "proxy_pass http://gunicorn_upstream/invocations"
CODE_DIR = (
"/opt/ml/code"
if os.environ.get("SAGEMAKER_MULTI_MODEL", "False").lower() == "true"
else "/opt/ml/model/code"
)
PYTHON_LIB_PATH = os.path.join(CODE_DIR, "lib")
REQUIREMENTS_PATH = os.path.join(CODE_DIR, "requirements.txt")
INFERENCE_PATH = os.path.join(CODE_DIR, "inference.py")
class ServiceManager(object):
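    """Supervises the serving stack inside the container: one nginx front end,
    one or more TensorFlow Serving instances, and (when a Python inference
    script or multi-model support is needed) a gunicorn-hosted Python service.
    """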
def __init__(self):
self._state = "initializing"
self._nginx = None
self._tfs = []
self._gunicorn = None
self._gunicorn_command = None
self._gunicorn_env = None
self._enable_python_service = False
self._tfs_version = os.environ.get("SAGEMAKER_TFS_VERSION", "1.13")
self._nginx_http_port = os.environ.get("SAGEMAKER_BIND_TO_PORT", "8080")
self._nginx_loglevel = os.environ.get("SAGEMAKER_TFS_NGINX_LOGLEVEL", "error")
self._tfs_default_model_name = os.environ.get("SAGEMAKER_TFS_DEFAULT_MODEL_NAME", "None")
self._sagemaker_port_range = os.environ.get("SAGEMAKER_SAFE_PORT_RANGE", None)
self._gunicorn_workers = os.environ.get("SAGEMAKER_GUNICORN_WORKERS", 1)
self._gunicorn_threads = os.environ.get("SAGEMAKER_GUNICORN_THREADS", 1)
self._gunicorn_loglevel = os.environ.get("SAGEMAKER_GUNICORN_LOGLEVEL", "info")
self._tfs_config_path = "/sagemaker/model-config.cfg"
self._tfs_batching_config_path = "/sagemaker/batching-config.cfg"
_enable_batching = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower()
_enable_multi_model_endpoint = os.environ.get("SAGEMAKER_MULTI_MODEL", "false").lower()
        # Fraction of GPU memory reserved for initializing CUDA/cuDNN and other GPU libraries
self._tfs_gpu_margin = float(os.environ.get("SAGEMAKER_TFS_FRACTIONAL_GPU_MEM_MARGIN", 0.2))
self._tfs_instance_count = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", 1))
self._tfs_wait_time_seconds = int(
os.environ.get("SAGEMAKER_TFS_WAIT_TIME_SECONDS", 55 // self._tfs_instance_count)
)
self._tfs_inter_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTER_OP_PARALLELISM", 0)
self._tfs_intra_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTRA_OP_PARALLELISM", 0)
self._gunicorn_worker_class = os.environ.get("SAGEMAKER_GUNICORN_WORKER_CLASS", "gevent")
self._gunicorn_timeout_seconds = int(
os.environ.get("SAGEMAKER_GUNICORN_TIMEOUT_SECONDS", 30)
)
self._nginx_proxy_read_timeout_seconds = int(
os.environ.get("SAGEMAKER_NGINX_PROXY_READ_TIMEOUT_SECONDS", 60)
)
        # The nginx proxy read timeout should not be less than the gunicorn timeout;
        # otherwise nginx can report upstream timeout errors.
        if self._gunicorn_timeout_seconds > self._nginx_proxy_read_timeout_seconds:
            log.info(
                "Gunicorn timeout is higher than the nginx proxy read timeout."
                " Raising the nginx proxy read timeout from {} seconds to {} seconds"
                " to match the gunicorn timeout.".format(
                    self._nginx_proxy_read_timeout_seconds, self._gunicorn_timeout_seconds
                )
            )
            self._nginx_proxy_read_timeout_seconds = self._gunicorn_timeout_seconds
if os.environ.get("OMP_NUM_THREADS") is None:
os.environ["OMP_NUM_THREADS"] = "1"
if _enable_multi_model_endpoint not in ["true", "false"]:
raise ValueError("SAGEMAKER_MULTI_MODEL must be 'true' or 'false'")
self._tfs_enable_multi_model_endpoint = _enable_multi_model_endpoint == "true"
self._need_python_service()
log.info("PYTHON SERVICE: {}".format(str(self._enable_python_service)))
if _enable_batching not in ["true", "false"]:
raise ValueError("SAGEMAKER_TFS_ENABLE_BATCHING must be 'true' or 'false'")
self._tfs_enable_batching = _enable_batching == "true"
self._use_gunicorn = self._enable_python_service or self._tfs_enable_multi_model_endpoint
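        # When SAGEMAKER_SAFE_PORT_RANGE is set, carve one gRPC/REST port pair per
        # TFS instance out of that range; otherwise fall back to the defaults below.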
if self._sagemaker_port_range is not None:
parts = self._sagemaker_port_range.split("-")
low = int(parts[0])
hi = int(parts[1])
self._tfs_grpc_ports = []
self._tfs_rest_ports = []
if low + 2 * self._tfs_instance_count > hi:
raise ValueError(
"not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})".format(
self._sagemaker_port_range
)
)
# select non-overlapping grpc and rest ports based on tfs instance count
for i in range(self._tfs_instance_count):
self._tfs_grpc_ports.append(str(low + 2 * i))
self._tfs_rest_ports.append(str(low + 2 * i + 1))
# concat selected ports respectively in order to pass them to python service
self._tfs_grpc_concat_ports = self._concat_ports(self._tfs_grpc_ports)
self._tfs_rest_concat_ports = self._concat_ports(self._tfs_rest_ports)
else:
# just use the standard default ports
self._tfs_grpc_ports = ["9000"]
self._tfs_rest_ports = ["8501"]
# provide single concat port here for default case
self._tfs_grpc_concat_ports = "9000"
self._tfs_rest_concat_ports = "8501"
# set environment variable for python service
os.environ["TFS_GRPC_PORTS"] = self._tfs_grpc_concat_ports
os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports
def _need_python_service(self):
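        """Enable the Python service when customer code (inference.py, requirements.txt,
        or a lib/ directory) is present, or when the universal-script bucket and prefix
        environment variables are configured."""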
if (
os.path.exists(INFERENCE_PATH)
or os.path.exists(REQUIREMENTS_PATH)
or os.path.exists(PYTHON_LIB_PATH)
):
self._enable_python_service = True
if os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET") and os.environ.get(
"SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"
):
self._enable_python_service = True
def _concat_ports(self, ports):
str_ports = [str(port) for port in ports]
concat_str_ports = ",".join(str_ports)
return concat_str_ports
def _create_tfs_config(self):
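        """Write a model_config_list covering every SavedModel bundle found, pinning
        each model's discovered versions, to self._tfs_config_path."""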
models = tfs_utils.find_models()
if not models:
raise ValueError("no SavedModel bundles found!")
if self._tfs_default_model_name == "None":
default_model = os.path.basename(models[0])
if default_model:
self._tfs_default_model_name = default_model
log.info("using default model name: {}".format(self._tfs_default_model_name))
else:
log.info("no default model detected")
# config (may) include duplicate 'config' keys, so we can't just dump a dict
config = "model_config_list: {\n"
for m in models:
config += " config: {\n"
config += " name: '{}'\n".format(os.path.basename(m))
config += " base_path: '{}'\n".format(m)
config += " model_platform: 'tensorflow'\n"
config += " model_version_policy: {\n"
config += " specific: {\n"
for version in tfs_utils.find_model_versions(m):
config += " versions: {}\n".format(version)
config += " }\n"
config += " }\n"
config += " }\n"
config += "}\n"
log.info("tensorflow serving model config: \n%s\n", config)
with open(self._tfs_config_path, "w", encoding="utf8") as f:
f.write(config)
def _setup_gunicorn(self):
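        """Prepare the gunicorn command line and environment: download universal scripts
        if needed, install requirements.txt (unless a lib/ directory is bundled), and set
        PYTHONPATH plus the TFS port and parallelism variables for python_service.py."""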
python_path_content = []
python_path_option = ""
bucket = os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET", None)
prefix = os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX", None)
if not os.path.exists(CODE_DIR) and bucket and prefix:
self._download_scripts(bucket, prefix)
if self._enable_python_service:
lib_path_exists = os.path.exists(PYTHON_LIB_PATH)
requirements_exists = os.path.exists(REQUIREMENTS_PATH)
python_path_content = ["/opt/ml/model/code"]
python_path_option = "--pythonpath "
if lib_path_exists:
python_path_content.append(PYTHON_LIB_PATH)
if requirements_exists:
if lib_path_exists:
log.warning(
"loading modules in '{}', ignoring requirements.txt".format(PYTHON_LIB_PATH)
)
else:
log.info("installing packages from requirements.txt...")
pip_install_cmd = "pip3 install -r {}".format(REQUIREMENTS_PATH)
try:
subprocess.check_call(pip_install_cmd.split())
except subprocess.CalledProcessError:
log.error("failed to install required packages, exiting.")
self._stop()
raise ChildProcessError("failed to install required packages.")
gunicorn_command = (
"python3 /sagemaker/python_service.py -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker "
"--workers {} --threads {} --log-level {} --timeout {} "
).format(
self._gunicorn_worker_class,
self._gunicorn_workers,
self._gunicorn_threads,
self._gunicorn_loglevel,
self._gunicorn_timeout_seconds,
)
log.info("gunicorn command: {}".format(gunicorn_command))
self._gunicorn_command = gunicorn_command
gunicorn_env = {
"TFS_GRPC_PORTS": self._tfs_grpc_concat_ports,
"TFS_REST_PORTS": self._tfs_rest_concat_ports,
"SAGEMAKER_MULTI_MODEL": str(self._tfs_enable_multi_model_endpoint),
"SAGEMAKER_TFS_WAIT_TIME_SECONDS": str(self._tfs_wait_time_seconds),
"SAGEMAKER_TFS_INTER_OP_PARALLELISM": str(self._tfs_inter_op_parallelism),
"SAGEMAKER_TFS_INTRA_OP_PARALLELISM": str(self._tfs_intra_op_parallelism),
"SAGEMAKER_TFS_INSTANCE_COUNT": str(self._tfs_instance_count),
"PYTHONPATH": ":".join(python_path_content),
"SAGEMAKER_GUNICORN_WORKERS": str(self._gunicorn_workers),
}
if self._sagemaker_port_range is not None:
gunicorn_env["SAGEMAKER_SAFE_PORT_RANGE"] = self._sagemaker_port_range
log.info(f"gunicorn env: {gunicorn_env}")
self._gunicorn_env = gunicorn_env
def _download_scripts(self, bucket, prefix):
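        """Download every object directly under the given S3 bucket and prefix (top level
        only, since the listing uses Delimiter='/') into CODE_DIR."""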
log.info("checking boto session region ...")
boto_session = boto3.session.Session()
boto_region = boto_session.region_name
if boto_region in ("us-iso-east-1", "us-gov-west-1"):
raise ValueError("Universal scripts is not supported in us-iso-east-1 or us-gov-west-1")
log.info("downloading universal scripts ...")
client = boto3.client("s3")
resource = boto3.resource("s3")
# download files
paginator = client.get_paginator("list_objects")
for result in paginator.paginate(Bucket=bucket, Delimiter="/", Prefix=prefix):
for file in result.get("Contents", []):
destination = os.path.join(CODE_DIR, file.get("Key").split("/")[-1])
if not os.path.exists(os.path.dirname(destination)):
os.makedirs(os.path.dirname(destination))
resource.meta.client.download_file(bucket, file.get("Key"), destination)
def _create_nginx_tfs_upstream(self):
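        """Build the nginx upstream server list with one 'server localhost:<port>;' entry
        per TFS REST port; the leading indent and trailing semicolon/newline are trimmed
        because the surrounding nginx.conf.template is expected to supply them."""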
indentation = " "
tfs_upstream = ""
for port in self._tfs_rest_ports:
tfs_upstream += "{}server localhost:{};\n".format(indentation, port)
tfs_upstream = tfs_upstream[len(indentation) : -2]
return tfs_upstream
def _create_nginx_config(self):
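        """Render /sagemaker/nginx.conf by substituting %NAME% placeholders in
        nginx.conf.template with the current ports, timeouts, and request-forwarding
        targets (gunicorn or the js_content handlers)."""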
template = self._read_nginx_template()
pattern = re.compile(r"%(\w+)%")
template_values = {
"TFS_VERSION": self._tfs_version,
"TFS_UPSTREAM": self._create_nginx_tfs_upstream(),
"TFS_DEFAULT_MODEL_NAME": self._tfs_default_model_name,
"NGINX_HTTP_PORT": self._nginx_http_port,
"NGINX_LOG_LEVEL": self._nginx_loglevel,
"FORWARD_PING_REQUESTS": GUNICORN_PING if self._use_gunicorn else JS_PING,
"FORWARD_INVOCATION_REQUESTS": GUNICORN_INVOCATIONS
if self._use_gunicorn
else JS_INVOCATIONS,
"PROXY_READ_TIMEOUT": str(self._nginx_proxy_read_timeout_seconds),
}
config = pattern.sub(lambda x: template_values[x.group(1)], template)
log.info("nginx config: \n%s\n", config)
with open("/sagemaker/nginx.conf", "w", encoding="utf8") as f:
f.write(config)
def _read_nginx_template(self):
with open("/sagemaker/nginx.conf.template", "r", encoding="utf8") as f:
template = f.read()
if not template:
raise ValueError("failed to read nginx.conf.template")
return template
def _enable_per_process_gpu_memory_fraction(self):
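        """Per-process GPU memory limits are only needed when several TFS instances share
        a GPU host, so return True only if nvidia-smi is present and more than one
        instance is configured."""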
nvidia_smi_exist = os.path.exists("/usr/bin/nvidia-smi")
if self._tfs_instance_count > 1 and nvidia_smi_exist:
return True
return False
def _get_number_of_gpu_on_host(self):
nvidia_smi_exist = os.path.exists("/usr/bin/nvidia-smi")
if nvidia_smi_exist:
return len(
subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip().split("\n")
)
return 0
def _calculate_per_process_gpu_memory_fraction(self):
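        """Split GPU memory evenly across TFS instances after reserving the configured
        margin (SAGEMAKER_TFS_FRACTIONAL_GPU_MEM_MARGIN) for CUDA/cuDNN initialization."""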
return round((1 - self._tfs_gpu_margin) / float(self._tfs_instance_count), 4)
def _start_tfs(self):
self._log_version("tensorflow_model_server --version", "tensorflow version info:")
for i in range(self._tfs_instance_count):
p = self._start_single_tfs(i)
self._tfs.append(p)
def _start_gunicorn(self):
self._log_version("gunicorn --version", "gunicorn version info:")
env = os.environ.copy()
env["TFS_DEFAULT_MODEL_NAME"] = self._tfs_default_model_name
env.update(self._gunicorn_env)
p = subprocess.Popen(self._gunicorn_command.split(), env=env)
log.info("started gunicorn (pid: %d)", p.pid)
self._gunicorn = p
def _start_nginx(self):
self._log_version("/usr/sbin/nginx -V", "nginx version info:")
p = subprocess.Popen("/usr/sbin/nginx -c /sagemaker/nginx.conf".split())
log.info("started nginx (pid: %d)", p.pid)
self._nginx = p
def _log_version(self, command, message):
try:
output = (
subprocess.check_output(command.split(), stderr=subprocess.STDOUT)
.decode("utf-8", "backslashreplace")
.strip()
)
log.info("{}\n{}".format(message, output))
except subprocess.CalledProcessError:
log.warning("failed to run command: %s", command)
def _stop(self, *args): # pylint: disable=W0613
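        """Signal handler and shutdown path: send SIGQUIT to nginx and SIGTERM to gunicorn
        and every TFS instance, ignoring processes that have already exited."""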
self._state = "stopping"
log.info("stopping services")
        try:
            if self._nginx:
                os.kill(self._nginx.pid, signal.SIGQUIT)
        except OSError:
            pass
try:
if self._gunicorn:
os.kill(self._gunicorn.pid, signal.SIGTERM)
except OSError:
pass
try:
for tfs in self._tfs:
os.kill(tfs.pid, signal.SIGTERM)
except OSError:
pass
self._state = "stopped"
log.info("stopped")
def _wait_for_gunicorn(self):
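        """Poll for gunicorn's unix socket; the caller bounds this loop with the
        _timeout() alarm, so it does not spin forever if gunicorn never comes up."""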
while True:
if os.path.exists("/tmp/gunicorn.sock"):
log.info("gunicorn server is ready!")
return
def _wait_for_tfs(self):
for i in range(self._tfs_instance_count):
tfs_utils.wait_for_model(
self._tfs_rest_ports[i], self._tfs_default_model_name, self._tfs_wait_time_seconds
)
@contextmanager
def _timeout(self, seconds):
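        """Context manager that raises TimeoutError via SIGALRM once `seconds` elapse,
        and always clears the alarm on exit."""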
def _raise_timeout_error(signum, frame):
raise TimeoutError("time out after {} seconds".format(seconds))
try:
signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(seconds)
yield
finally:
signal.alarm(0)
def _is_tfs_process(self, pid):
for p in self._tfs:
if p.pid == pid:
return True
return False
def _find_tfs_process(self, pid):
for index, p in enumerate(self._tfs):
if p.pid == pid:
return index
return None
def _restart_single_tfs(self, pid):
instance_id = self._find_tfs_process(pid)
if instance_id is None:
raise ValueError("Cannot find tfs with pid: {};".format(pid))
p = self._start_single_tfs(instance_id)
self._tfs[instance_id] = p
def _start_single_tfs(self, instance_id):
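        """Launch one tensorflow_model_server process; on multi-GPU hosts each instance
        is pinned to a GPU round-robin via CUDA_VISIBLE_DEVICES."""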
cmd = tfs_utils.tfs_command(
self._tfs_grpc_ports[instance_id],
self._tfs_rest_ports[instance_id],
self._tfs_config_path,
self._tfs_enable_batching,
self._tfs_batching_config_path,
tfs_intra_op_parallelism=self._tfs_intra_op_parallelism,
tfs_inter_op_parallelism=self._tfs_inter_op_parallelism,
tfs_enable_gpu_memory_fraction=self._enable_per_process_gpu_memory_fraction(),
tfs_gpu_memory_fraction=self._calculate_per_process_gpu_memory_fraction(),
)
log.info("tensorflow serving command: {}".format(cmd))
num_gpus = self._get_number_of_gpu_on_host()
if num_gpus > 1:
# utilizing multi-gpu
worker_env = os.environ.copy()
worker_env["CUDA_VISIBLE_DEVICES"] = str(instance_id % num_gpus)
p = subprocess.Popen(cmd.split(), env=worker_env)
log.info(
"started tensorflow serving (pid: {}) on GPU: {}".format(
p.pid, instance_id % num_gpus
)
)
else:
# cpu and single gpu
p = subprocess.Popen(cmd.split())
log.info("started tensorflow serving (pid: {})".format(p.pid))
return p
def _monitor(self):
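        """Reap child processes with os.wait() and, while the manager is in the 'started'
        state, restart whichever of nginx, TFS, or gunicorn exited unexpectedly."""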
while True:
pid, status = os.wait()
if self._state != "started":
break
if pid == self._nginx.pid:
log.warning("unexpected nginx exit (status: {}). restarting.".format(status))
self._start_nginx()
elif self._is_tfs_process(pid):
log.warning(
"unexpected tensorflow serving exit (status: {}). restarting.".format(status)
)
try:
self._restart_single_tfs(pid)
except (ValueError, OSError) as error:
log.error("Failed to restart tensorflow serving. {}".format(error))
elif self._gunicorn and pid == self._gunicorn.pid:
log.warning("unexpected gunicorn exit (status: {}). restarting.".format(status))
self._start_gunicorn()
def start(self):
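        """Entry point: install the SIGTERM handler, write configs, start TFS (deferred
        for multi-model endpoints), optionally start gunicorn, start nginx, then monitor
        children until asked to stop."""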
log.info("starting services")
self._state = "starting"
signal.signal(signal.SIGTERM, self._stop)
if self._tfs_enable_batching:
log.info("batching is enabled")
tfs_utils.create_batching_config(self._tfs_batching_config_path)
if self._tfs_enable_multi_model_endpoint:
log.info("multi-model endpoint is enabled, TFS model servers will be started later")
else:
self._create_tfs_config()
self._start_tfs()
self._wait_for_tfs()
self._create_nginx_config()
if self._use_gunicorn:
self._setup_gunicorn()
self._start_gunicorn()
# make sure gunicorn is up
with self._timeout(seconds=self._gunicorn_timeout_seconds):
self._wait_for_gunicorn()
self._start_nginx()
self._state = "started"
self._monitor()
self._stop()
if __name__ == "__main__":
ServiceManager().start()