src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
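For context, this excerpt relies on the following module-level names. A minimal sketch of what they are assumed to be; the literal values shown are illustrative assumptions, not part of the excerpt:

import os

import torch

# Assumed definitions (illustrative; the real ones live elsewhere in the toolkit).
INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
DEFAULT_MODEL_FILENAME = "model.pt"


class ModelLoadError(Exception):
    """Raised when a model artifact cannot be deserialized."""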
def default_model_fn(self, model_dir):
    """Loads a model. For PyTorch, the default function loads a TorchScript
    artifact: when Elastic Inference is used, the model is loaded onto CPU
    (inference runs on the accelerator); otherwise it is loaded onto GPU when
    available, falling back to CPU. Models not saved with TorchScript require
    a customized model_fn() in the entry-point script.

    Args:
        model_dir: a directory where the model is saved.

    Returns:
        A PyTorch model.
    """
    if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
        model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                "Failed to load model with default model_fn: missing file {}.".format(
                    DEFAULT_MODEL_FILENAME
                )
            )
        # The client framework is CPU-only; the model runs with CUDA on the
        # Elastic Inference accelerator, so always map it to CPU here.
        try:
            return torch.jit.load(model_path, map_location=torch.device("cpu"))
        except RuntimeError as e:
            raise ModelLoadError(
                "Failed to load {}. Please ensure the model is saved with TorchScript.".format(model_path)
            ) from e
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
        if not os.path.exists(model_path):
            # Fall back to a single .pth/.pt artifact in model_dir when the
            # default filename is absent.
            model_files = [file for file in os.listdir(model_dir) if self._is_model_file(file)]
            if len(model_files) != 1:
                raise ValueError(
                    "Exactly one .pth or .pt file is required for PyTorch models: {}".format(model_files)
                )
            model_path = os.path.join(model_dir, model_files[0])
        try:
            model = torch.jit.load(model_path, map_location=device)
        except RuntimeError as e:
            raise ModelLoadError(
                "Failed to load {}. Please ensure the model is saved with TorchScript.".format(model_path)
            ) from e
        model = model.to(device)
        return model
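
For the default loader to succeed, the artifact must be serialized with TorchScript under the expected filename. A minimal sketch of producing such an artifact, assuming `net` is an already-trained torch.nn.Module instance and that DEFAULT_MODEL_FILENAME resolves to "model.pt":

import torch

scripted = torch.jit.script(net)  # net: hypothetical trained torch.nn.Module
scripted.save("model.pt")         # filename assumed to match DEFAULT_MODEL_FILENAME

A model saved as a plain state_dict cannot be read by torch.jit.load, so in that case the entry-point script should supply its own model_fn(). A minimal sketch, assuming a hypothetical MyModel architecture class and a "model.pth" weights file:

# inference.py -- user entry-point script (illustrative)
import os

import torch

from my_package.model import MyModel  # hypothetical user-defined architecture


def model_fn(model_dir):
    """Rebuild the architecture and load weights saved as a plain state_dict."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyModel()
    state_dict = torch.load(os.path.join(model_dir, "model.pth"), map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()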