in src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py [0:0]
def default_input_fn(self, input_data, content_type):
    """Deserialize a request payload into a tensor on the inference device.

    The payload is first decoded into a NumPy array by ``decoder.decode`` and
    then converted to a tensor:

    * content types listed in ``content_types.UTF8_TYPES`` are converted to
      ``float32`` (the dtype of ``torch.FloatTensor``);
    * any other content type is wrapped with ``torch.from_numpy``, which keeps
      the dtype produced by the decoder (it is NOT forced to float32) and
      shares memory with the decoded array.

    Finally the tensor is moved to CUDA when ``torch.cuda.is_available()``,
    otherwise it stays on CPU.

    NOTE(review): the previous docstring listed JSON, CSV and NPZ as the
    handled formats; the accepted set is whatever ``decoder.decode`` supports
    — confirm against that module.

    Args:
        input_data: the request payload serialized in the ``content_type`` format.
        content_type: the request content type.

    Returns:
        torch.Tensor: the deserialized payload, on CUDA if available else CPU.

    Raises:
        Exceptions from ``decoder.decode`` are not caught here and propagate
        to the caller.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    np_array = decoder.decode(input_data, content_type)
    if content_type in content_types.UTF8_TYPES:
        # Same float32 result as the legacy ``torch.FloatTensor(np_array)``
        # type constructor, using the recommended factory function instead.
        tensor = torch.as_tensor(np_array, dtype=torch.float32)
    else:
        # Zero-copy wrap; preserves the dtype chosen by the decoder.
        tensor = torch.from_numpy(np_array)
    return tensor.to(device)