def predict_fn()

in sagemaker/src/hf_train_deploy.py [0:0]


def predict_fn(input_data, model):
    """Model prediction for a single input.

    Runs ``model`` on ``input_data`` (on GPU when available, otherwise CPU)
    and returns the softmax probabilities for the first item of the batch.

    Args:
        input_data: Model inputs. Must provide ``.to(device)`` and unpack as
            keyword arguments (``model(**input_data)``); presumably a Hugging
            Face tokenizer ``BatchEncoding`` — TODO confirm against the
            input_fn that produces it.
        model: Module whose call result supports ``result['logits']``.

    Returns:
        numpy.ndarray: Softmax of the logits along dim 1, for row 0 of the
        batch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference mode (e.g. disables dropout)
    input_data = input_data.to(device)
    with torch.no_grad():
        logits = model(**input_data)["logits"]
        # Assumes logits are (batch, num_labels) — TODO confirm; softmax over dim 1.
        probs = torch.softmax(logits, dim=1)
    # .numpy() only accepts CPU tensors, so copy back from the GPU first;
    # the original .detach().numpy() raised TypeError whenever CUDA was used.
    return probs.cpu().numpy()[0]