# in notebooks/modelscript_pytorch.py [0:0]
def predict(model, payload, input_shape=(1, 1, 28, 28)):
    """Run a single forward pass and return the winning output index.

    Args:
        model: Callable PyTorch module. It is switched to eval mode before
            the forward pass.
        payload: Either a list whose first element is a mapping with a
            ``'body'`` entry holding raw float32 bytes, or a NumPy array
            that is handed to the model unchanged.
        input_shape: Shape the decoded ``'body'`` buffer is reshaped to.
            Defaults to ``(1, 1, 28, 28)``, the value previously hard-coded
            here.

    Returns:
        A one-element list. On success it holds the argmax along dim 1 of
        the first row of the model output. On failure (unsupported payload
        type, or any exception during tensor conversion or the forward
        pass) it holds the error message as a string instead.

    Raises:
        Errors from decoding a list payload (e.g. ``IndexError`` for an
        empty list, ``KeyError`` for a missing ``'body'``, ``ValueError``
        when the buffer does not fit ``input_shape``) propagate to the
        caller, exactly as before.
    """
    if isinstance(payload, list):
        data = np.frombuffer(payload[0]['body'], dtype=np.float32).reshape(input_shape)
    elif isinstance(payload, np.ndarray):
        data = payload
    else:
        # Previously this fell through with `data` unbound and surfaced as
        # an opaque UnboundLocalError message; report the real cause while
        # keeping the same "error string in a list" return shape.
        return [f"Unsupported payload type: {type(payload).__name__}"]

    try:
        input_data = torch.Tensor(data)
        model.eval()
        with torch.no_grad():
            # NOTE: `device` is not defined in this block; it is expected to
            # exist at module level elsewhere in the file.
            out = model(input_data.to(device)).argmax(dim=1)[0].tolist()
    except Exception as e:
        # Best-effort contract: inference failures are returned to the
        # caller as text rather than raised.
        out = str(e)
    return [out]