def model_fn()

in pytorch_code/classifier/classifier.py [0:0]


def model_fn(model_dir):
    """Load the exported learner stored in ``model_dir``.

    The serialized model is expected at ``<model_dir>/model.pkl``. It is
    copied to ``/tmp/export.pkl`` and then loaded by calling
    ``load_learner(path='/tmp')``.

    Args:
        model_dir: Directory that contains ``model.pkl``.

    Returns:
        Whatever ``load_learner`` returns for the copied file.

    Raises:
        OSError: Presumably, if ``model.pkl`` is missing or ``/tmp`` is not
            writable (assuming ``copyfile`` is ``shutil.copyfile`` -- confirm
            against the file's imports).
    """
    print('model_fn')

    # Query the device count once and reuse it for the check and the message.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Gpu count: {}".format(gpu_count))

    # NOTE(review): presumably load_learner is fastai's, which looks for a
    # file named 'export.pkl' under `path` by default -- hence the rename
    # while copying. Confirm before simplifying this step.
    copyfile(os.path.join(model_dir, 'model.pkl'), '/tmp/export.pkl')
    return load_learner(path='/tmp')