def predict_fn()

in code/inference.py [0:0]


def predict_fn(data, model):
    """Run a forward pass of ``model`` on ``data`` with autograd disabled.

    Both the model and the input are moved to the device returned by
    ``get_device()``, and the model is switched to evaluation mode before
    it is called.

    Args:
        data: Input passed to the model; must expose ``.to(device)``
            (presumably a ``torch.Tensor`` -- confirm against the caller).
        model: Callable exposing ``.to(device)`` and ``.eval()``
            (presumably a ``torch.nn.Module`` -- confirm against the caller).

    Returns:
        Whatever ``model(...)`` returns for the device-resident input.
    """
    print('in custom predict function')

    # The whole sequence stays inside no_grad so no step records autograd
    # history.
    with torch.no_grad():
        target_device = get_device()
        model = model.to(target_device)
        batch = data.to(target_device)
        model.eval()
        return model(batch)