in code/inference.py [0:0]
def model_fn(model_dir):
    """Load the serialized model stored at ``<model_dir>/model.pth``.

    The tensors are mapped onto the device reported by ``get_device()``, so
    a checkpoint saved on GPU can still be loaded on a CPU-only host (and
    vice versa).

    NOTE(review): the name ``model_fn`` matches the model-loading hook of
    the SageMaker PyTorch serving container; presumably that is the caller,
    so confirm before renaming.

    Args:
        model_dir: Directory containing the ``model.pth`` artifact.

    Returns:
        Whatever object ``torch.load`` deserializes from ``model.pth``. The
        ``type`` print below suggests this is expected to be a full model
        object rather than a bare state dict; confirm against the training
        script.
    """
    import os  # local import: the file's top-level imports are not visible here

    device = get_device()
    print('device is')
    print(device)

    # os.path.join avoids producing "/model.pth" (filesystem root) when
    # model_dir is empty, and handles a trailing separator cleanly.
    model_path = os.path.join(model_dir, 'model.pth')

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code, so only load artifacts from a trusted source.
    # NOTE(review): PyTorch 2.6+ defaults to weights_only=True, which may
    # refuse to load a fully pickled nn.Module. If loading fails after an
    # upgrade, pass weights_only=False explicitly (trusted artifacts only).
    model = torch.load(model_path, map_location=torch.device(device))
    print(type(model))
    return model