in code/inference.py [0:0]
def predict_fn(data, model):
    """Run one forward pass of ``model`` over ``data`` with autograd disabled.

    The model and the input are both moved onto the device returned by
    ``get_device()`` (a helper defined elsewhere in this file). The model is
    switched to evaluation mode and then called on the moved input.

    Args:
        data: Model input. It must support ``.to(device)``. NOTE(review):
            presumably a ``torch.Tensor`` produced by the input handler;
            confirm.
        model: The object to run inference with. It must support ``.to``,
            ``.eval`` and being called (assumed to be a ``torch.nn.Module``).

    Returns:
        Whatever ``model(...)`` returns. It is returned as-is, with no
        ``.cpu()`` or other post-processing here.
    """
    print('in custom predict function')
    # Gradient tracking is unnecessary for inference and would only cost memory.
    with torch.no_grad():
        target = get_device()
        net = model.to(target)
        batch = data.to(target)
        # Put layers such as dropout / batch-norm into inference behaviour.
        net.eval()
        return net(batch)