# pytorch/inference/docker/build_artifacts/default_inference_handler.py
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
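"""Default SageMaker inference handlers for PyTorch models attached to Elastic Inference."""
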
from __future__ import absolute_import

import os

import torch
import torcheia
from sagemaker_inference import (
content_types,
decoder,
default_inference_handler,
encoder,
errors,
utils,
)

INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
DEFAULT_MODEL_FILENAME = "model.pt"
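
# Disable the TorchScript profiling executor; Elastic Inference expects the legacy executor
# when running attached models.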
torch._C._jit_set_profiling_executor(False)
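# The client-side framework runs on CPU; the model itself executes on the attached
# Elastic Inference accelerator.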
device = torch.device("cpu")


class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)

    def default_model_fn(self, model_dir):
        """Loads a model. For PyTorch, a default model_fn is provided only when Elastic Inference
        is used; in other cases, users should provide a customized model_fn() in their script.
Args:
model_dir: a directory where model is saved.
Returns: A PyTorch model.
"""
model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
if not os.path.exists(model_path):
raise FileNotFoundError(
"Failed to load model with default model_fn: missing file {}.".format(
DEFAULT_MODEL_FILENAME
)
)
        # The client framework is CPU-only, but the model will run on the Elastic Inference accelerator with CUDA.
model = torch.jit.load(model_path, map_location=device)
model.eval()
model = model.to(device)
        # attach_eia() was introduced in PyTorch Elastic Inference 1.5.1.
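        # Attach the loaded TorchScript model to the first Elastic Inference accelerator (device ordinal 0).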
model = torcheia.jit.attach_eia(model, 0)
return model

    def default_input_fn(self, input_data, content_type):
        """A default input_fn that can handle JSON, CSV and NPY formats.
        Args:
            input_data: the request payload serialized in the content_type format
            content_type: the request content_type
        Returns: input_data deserialized into a torch.FloatTensor on the CPU device
            (the Elastic Inference client framework is CPU-only).
        """
np_array = decoder.decode(input_data, content_type)
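        # JSON/CSV (UTF-8) payloads are coerced to float32 via FloatTensor; NPY arrays keep
        # their original dtype via from_numpy.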
tensor = (
torch.FloatTensor(np_array)
if content_type in content_types.UTF8_TYPES
else torch.from_numpy(np_array)
)
return tensor.to(device)

    def default_predict_fn(self, data, model):
        """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
        With Elastic Inference, the client framework runs on CPU and the forward pass is served
        by the attached accelerator.
        Args:
            data: input data (torch.Tensor) for prediction deserialized by input_fn
            model: PyTorch model loaded in memory by model_fn
        Returns: a prediction
        """
input_data = data.to(device)
with torch.no_grad():
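            # Run the TorchScript graph under optimized execution; with the EIA-attached model,
            # the forward pass is served by the accelerator.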
with torch.jit.optimized_execution(True):
output = model.forward(input_data)
return output

    def default_output_fn(self, prediction, accept):
        """A default output_fn for PyTorch. Serializes predictions from predict_fn to JSON, CSV,
        or NPY format.
        Args:
            prediction: a prediction result from predict_fn
            accept: content type to which the output data should be serialized
        Returns: output data serialized
        """
        if isinstance(prediction, torch.Tensor):
prediction = prediction.detach().cpu().numpy().tolist()
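        # Serialize with the first content type in the Accept header that the encoder supports.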
for content_type in utils.parse_accept(accept):
if content_type in encoder.SUPPORTED_CONTENT_TYPES:
encoded_prediction = encoder.encode(prediction, content_type)
if content_type == content_types.CSV:
encoded_prediction = encoded_prediction.encode("utf-8")
return encoded_prediction
raise errors.UnsupportedFormatError(accept)