def predict()

in notebooks/modelscript_pytorch.py [0:0]


def predict(model, payload, target_device=None):
    """Run ``model`` on ``payload`` and return the top-1 class of the first sample.

    Args:
        model: A torch module (or compatible callable with ``eval()``); it is
            switched to eval mode and called under ``torch.no_grad()``.
        payload: Either
            * a list whose first element is a mapping with a ``'body'`` entry
              holding raw native-endian float32 bytes, decoded and reshaped to
              ``(1, 1, 28, 28)`` (so the body must contain exactly 784 floats), or
            * a ``numpy.ndarray`` handed to the model as-is (its shape is not
              validated here).
        target_device: Optional device the input tensor is moved to. When
            ``None`` (the default) the module-level ``device`` is used, matching
            the previous behavior.

    Returns:
        A one-element list containing either the ``argmax`` over dim 1 of the
        model output for the first sample (as a Python number), or an error
        message string when the payload type is unsupported or decoding /
        inference fails. Errors are reported as strings rather than raised.
    """
    try:
        if isinstance(payload, list):
            # Body is a flat float32 buffer; reshape fixes it to a single 1x28x28 sample.
            data = np.frombuffer(payload[0]['body'], dtype=np.float32).reshape(1, 1, 28, 28)
        elif isinstance(payload, np.ndarray):
            data = payload
        else:
            # Previously this fell through with ``data`` unbound; report it explicitly.
            return [f"Unsupported payload type: {type(payload).__name__}"]

        input_data = torch.Tensor(data)
        # Resolve lazily so the module-level ``device`` is only needed without an override.
        run_device = device if target_device is None else target_device
        model.eval()
        with torch.no_grad():
            out = model(input_data.to(run_device)).argmax(dim=1)[0].tolist()
    except Exception as e:
        # Best-effort serving contract: surface any decode/inference failure as a string.
        out = str(e)
    return [out]