tensorflow/inference/docker/build_artifacts/sagemaker_neuron/tfs_utils.py
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json
import logging
import multiprocessing
import os
import re
import time
from collections import namedtuple

import requests
from urllib3.exceptions import MaxRetryError, NewConnectionError
from urllib3.util.retry import Retry

from multi_model_utils import timeout

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

DEFAULT_CONTENT_TYPE = "application/json"
DEFAULT_ACCEPT_HEADER = "application/json"
CUSTOM_ATTRIBUTES_HEADER = "X-Amzn-SageMaker-Custom-Attributes"

Context = namedtuple(
    "Context",
    "model_name, model_version, method, rest_uri, grpc_port, channel, "
    "custom_attributes, request_content_type, accept_header, content_length",
)


def parse_request(req, rest_port, grpc_port, default_model_name, model_name=None, channel=None):
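    """Parse an incoming request into a body stream and a TFS Context.

    When model_name is not given, it is taken from the tfs-model-name
    custom attribute, so multi-model requests can route by header alone.
    """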
tfs_attributes = parse_tfs_custom_attributes(req)
tfs_uri = make_tfs_uri(rest_port, tfs_attributes, default_model_name, model_name)
if not model_name:
model_name = tfs_attributes.get("tfs-model-name")
context = Context(
model_name,
tfs_attributes.get("tfs-model-version"),
tfs_attributes.get("tfs-method"),
tfs_uri,
grpc_port,
channel,
req.get_header(CUSTOM_ATTRIBUTES_HEADER),
req.get_header("Content-Type") or DEFAULT_CONTENT_TYPE,
req.get_header("Accept") or DEFAULT_ACCEPT_HEADER,
req.content_length,
)
data = req.stream
return data, context


def make_tfs_uri(port, attributes, default_model_name, model_name=None):
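    """Build the TFS REST URI for a model, defaulting to the :predict method.

    Illustrative result (model name and port here are hypothetical):
    http://localhost:8501/v1/models/my_model/versions/1:predict
    """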
log.info("sagemaker tfs attributes: \n{}".format(attributes))
tfs_model_name = model_name or attributes.get("tfs-model-name", default_model_name)
tfs_model_version = attributes.get("tfs-model-version")
tfs_method = attributes.get("tfs-method", "predict")
uri = "http://localhost:{}/v1/models/{}".format(port, tfs_model_name)
if tfs_model_version:
uri += "/versions/" + tfs_model_version
uri += ":" + tfs_method
return uri


def parse_tfs_custom_attributes(req):
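    """Extract tfs-* key=value pairs from the custom-attributes header.

    For example, a header of "tfs-model-name=other,tfs-method=predict"
    yields {"tfs-model-name": "other", "tfs-method": "predict"}.
    """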
attributes = {}
header = req.get_header(CUSTOM_ATTRIBUTES_HEADER)
if header:
matches = re.findall(r"(tfs-[a-z\-]+=[^,]+)", header)
attributes = dict(attribute.split("=") for attribute in matches)
return attributes


def create_tfs_config_individual_model(model_name, base_path):
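    """Render a single-model TFS model_config_list in text proto format,
    pinning model_version_policy to the versions found under base_path.
    """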
config = "model_config_list: {\n"
config += " config: {\n"
config += " name: '{}'\n".format(model_name)
config += " base_path: '{}'\n".format(base_path)
config += " model_platform: 'tensorflow'\n"
config += " model_version_policy: {\n"
config += " specific: {\n"
for version in find_model_versions(base_path):
config += " versions: {}\n".format(version)
config += " }\n"
config += " }\n"
config += " }\n"
config += "}\n"
return config


def tfs_command(
tfs_grpc_port,
tfs_rest_port,
tfs_config_path,
tfs_enable_batching,
tfs_batching_config_file,
tfs_intra_op_parallelism=None,
tfs_inter_op_parallelism=None,
tfs_enable_gpu_memory_fraction=False,
tfs_gpu_memory_fraction=None,
):
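    """Compose the tensorflow_model_server_neuron command line, appending
    batching, parallelism, and GPU memory flags only when they are enabled.
    """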
cmd = (
"tensorflow_model_server_neuron "
"--port={} "
"--rest_api_port={} "
"--model_config_file={} "
"--max_num_load_retries=0 {} {} {} {}".format(
tfs_grpc_port,
tfs_rest_port,
tfs_config_path,
get_tfs_batching_args(tfs_enable_batching, tfs_batching_config_file),
get_tensorflow_intra_op_parallelism_args(tfs_intra_op_parallelism),
get_tensorflow_inter_op_parallelism_args(tfs_inter_op_parallelism),
get_tfs_gpu_mem_args(tfs_enable_gpu_memory_fraction, tfs_gpu_memory_fraction),
)
)
return cmd


def find_models():
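    """Scan /opt/ml/model for SavedModels and return their base paths,
    i.e. each directory that contains numeric version subdirectories.
    """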
base_path = "/opt/ml/model"
models = []
for f in _find_saved_model_files(base_path):
parts = f.split("/")
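        # A path such as /opt/ml/model/<name>/<version>/saved_model.pb splits
        # into at least six parts, and its parent directory must be numeric.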
if len(parts) >= 6 and re.match(r"^\d+$", parts[-2]):
model_path = "/".join(parts[0:-2])
if model_path not in models:
models.append(model_path)
return models


def find_model_versions(model_path):
    """Return the numeric version directories under model_path with leading
    zeros stripped, e.g. "0001" becomes "1" while "0" stays "0"."""
return [
version[:-1].lstrip("0") + version[-1]
for version in os.listdir(model_path)
if version.isnumeric()
]


def _find_saved_model_files(path):
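    """Recursively yield the path of every saved_model.pb file under path."""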
for e in os.scandir(path):
if e.is_dir():
yield from _find_saved_model_files(os.path.join(path, e.name))
else:
if e.name == "saved_model.pb":
yield os.path.join(path, e.name)


def get_tfs_batching_args(enable_batching, tfs_batching_config):
    if enable_batching:
        return "--enable_batching=true --batching_parameters_file={}".format(
            tfs_batching_config
        )
    else:
        return ""


def get_tensorflow_intra_op_parallelism_args(tfs_intra_op_parallelism):
if tfs_intra_op_parallelism:
return "--tensorflow_intra_op_parallelism={}".format(tfs_intra_op_parallelism)
else:
return ""


def get_tensorflow_inter_op_parallelism_args(tfs_inter_op_parallelism):
if tfs_inter_op_parallelism:
return "--tensorflow_inter_op_parallelism={}".format(tfs_inter_op_parallelism)
else:
return ""


def get_tfs_gpu_mem_args(enable_gpu_memory_fraction, gpu_memory_fraction):
if enable_gpu_memory_fraction and gpu_memory_fraction:
return "--per_process_gpu_memory_fraction={}".format(gpu_memory_fraction)
else:
return ""


def create_batching_config(batching_config_file):
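    """Write a TFS batching parameters file to batching_config_file.

    Each parameter falls back to a default (logged as a warning) unless the
    corresponding SAGEMAKER_TFS_* environment variable overrides it.
    """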
class _BatchingParameter:
def __init__(self, key, env_var, value, defaulted_message):
self.key = key
self.env_var = env_var
self.value = value
self.defaulted_message = defaulted_message

    cpu_count = multiprocessing.cpu_count()
batching_parameters = [
_BatchingParameter(
"max_batch_size",
"SAGEMAKER_TFS_MAX_BATCH_SIZE",
8,
"max_batch_size defaulted to {}. Set {} to override default. "
"Tuning this parameter may yield better performance.",
),
_BatchingParameter(
"batch_timeout_micros",
"SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS",
1000,
"batch_timeout_micros defaulted to {}. Set {} to override "
"default. Tuning this parameter may yield better performance.",
),
_BatchingParameter(
"num_batch_threads",
"SAGEMAKER_TFS_NUM_BATCH_THREADS",
cpu_count,
"num_batch_threads defaulted to {}," "the number of CPUs. Set {} to override default.",
),
_BatchingParameter(
"max_enqueued_batches",
"SAGEMAKER_TFS_MAX_ENQUEUED_BATCHES",
            # SageMaker Batch Transform limits the number of concurrent
            # requests, which in turn limits the number of enqueued batches,
            # so this can be set high when running in Batch.
            100000000 if "SAGEMAKER_BATCH" in os.environ else cpu_count,
            "max_enqueued_batches defaulted to {}. Set {} to override default. "
            "Tuning this parameter may be necessary if out-of-memory "
            "errors occur.",
),
]
warning_message = ""
for batching_parameter in batching_parameters:
if batching_parameter.env_var in os.environ:
batching_parameter.value = os.environ[batching_parameter.env_var]
else:
warning_message += batching_parameter.defaulted_message.format(
batching_parameter.value, batching_parameter.env_var
)
warning_message += "\n"
if warning_message:
log.warning(warning_message)
config = ""
for batching_parameter in batching_parameters:
config += "%s { value: %s }\n" % (batching_parameter.key, batching_parameter.value)
log.info("batching config: \n%s\n", config)
with open(batching_config_file, "w", encoding="utf8") as f:
f.write(config)


def wait_for_model(rest_port, model_name, timeout_seconds, wait_interval_seconds=5):
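    """Poll the TFS REST endpoint until every version of model_name reports
    state AVAILABLE, checking every wait_interval_seconds and aborting after
    timeout_seconds via the multi_model_utils.timeout context manager.
    """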
tfs_url = "http://localhost:{}/v1/models/{}".format(rest_port, model_name)
    with timeout(timeout_seconds):
        # Reuse one session with connection-level retries across polls.
        session = requests.Session()
        retries = Retry(total=9, backoff_factor=0.1)
        session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retries))
        while True:
            try:
                log.info("Trying to connect with model server: {}".format(tfs_url))
                response = session.get(tfs_url)
                log.info(response)
                if response.status_code == 200:
                    versions = json.loads(response.content)["model_version_status"]
                    if all(version["state"] == "AVAILABLE" for version in versions):
                        break
            except (
                ConnectionRefusedError,
                NewConnectionError,
                MaxRetryError,
                requests.exceptions.ConnectionError,
            ):
                log.warning("model: {} is not available yet".format(tfs_url))
            # Sleep between polls even when the server responds but the model
            # is still loading, so this loop does not spin hot.
            time.sleep(wait_interval_seconds)
        log.info("model: {} is available now".format(tfs_url))