in src/sagemaker_huggingface_inference_toolkit/handler_service.py [0:0]
def predict(self, data, model, context=None):
    """Run a prediction by calling ``model`` on the deserialized request payload.

    The payload is split into the positional input and optional keyword
    arguments for the model call:

    * ``data["inputs"]`` is passed as the positional argument when present;
      otherwise the whole payload (without the ``"parameters"`` entry) is
      passed instead.
    * ``data["parameters"]``, when present and not ``None``, is unpacked as
      keyword arguments of the model call.

    The caller's ``data`` mapping is not modified.

    The predict handler can be overridden to implement custom model inference.

    Args:
        data (dict): deserialized, decoded input data returned by the input_fn.
        model: callable returned by the `load` method or, for a custom module,
            by `model_fn`.
        context (obj): metadata on the incoming request data (default: None).
            Not used by this implementation.

    Returns:
        obj: whatever the ``model`` call returns.
    """
    # Work on a shallow copy so popping keys does not alter the caller's payload.
    payload = dict(data)

    # Remove "parameters" first so that, when there is no "inputs" key, the
    # fallback input below is the payload without the parameters entry.
    parameters = payload.pop("parameters", None)
    inputs = payload.pop("inputs", payload)

    # Forward the parameters as keyword arguments only when they were supplied.
    if parameters is not None:
        return model(inputs, **parameters)
    return model(inputs)