def model_fn()

in code/inference.py [0:0]


def model_fn(model_dir):
    """Load the serialized model stored at ``<model_dir>/model.pth``.

    Presumably this is the model-loading hook of a SageMaker-style PyTorch
    inference script (``model_fn`` in ``inference.py``) -- confirm against the
    serving container in use.

    Args:
        model_dir: Directory (str or path-like) containing the ``model.pth``
            artifact.

    Returns:
        Whatever object ``torch.load`` deserializes from ``model.pth``, with its
        tensors mapped onto the device returned by ``get_device()``.
    """
    import os  # local import keeps this block self-contained

    device = get_device()
    print('device is')
    print(device)
    # Build the path with os.path.join rather than string concatenation so a
    # path-like model_dir works, and an empty model_dir does not silently turn
    # into the filesystem root ('/model.pth').
    model_path = os.path.join(model_dir, 'model.pth')
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load artifacts from a trusted location.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which may reject a pickled full nn.Module -- confirm
    # what model.pth contains and which torch version is installed.
    model = torch.load(model_path, map_location=torch.device(device))
    print(type(model))
    return model