def model_fn()

in code/uncompiled-inference.py [0:0]


def model_fn(model_dir, model_filename="model.pth"):
    """Load the serialized TorchScript model used for inference.

    NOTE(review): the ``model_fn(model_dir)`` signature presumably follows the
    SageMaker PyTorch inference-handler convention, where the serving stack
    passes the directory holding the extracted model artifacts — confirm.

    Args:
        model_dir: Directory containing the serialized model file.
        model_filename: Name of the TorchScript file inside ``model_dir``.
            Defaults to ``"model.pth"``, the file name this loader has
            always used.

    Returns:
        The module returned by ``torch.jit.load``. Its tensors are mapped onto
        CUDA when a GPU is available, otherwise onto the CPU.
    """
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve the file against model_dir instead of the process's current
    # working directory. The CWD is not guaranteed to be the artifact
    # directory, and ignoring model_dir made loading depend on where the
    # process was launched.
    model_path = os.path.join(model_dir, model_filename)
    model = torch.jit.load(model_path, map_location=device)
    return model