def model_fn()

in code/inference.py [0:0]


def model_fn(model_dir):
    """Load the trained classifier from *model_dir* and prepare it for inference.

    Reads the state dict stored at ``<model_dir>/model.pth``, loads it into a
    freshly constructed ``ProteinClassifier(10)``, moves the model to the GPU
    when CUDA is available (CPU otherwise) and switches it to evaluation mode.

    NOTE(review): the name and signature look like a model-serving loader
    hook (e.g. SageMaker's ``model_fn``) -- confirm against the deployment.

    Args:
        model_dir: Directory that contains the ``model.pth`` weights file.

    Returns:
        The ``ProteinClassifier`` instance with the weights loaded, placed on
        the selected device and set to eval mode.

    Raises:
        FileNotFoundError: If ``model.pth`` does not exist in *model_dir*.
    """
    logger.info('model_fn')
    print('Loading the trained model...')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ProteinClassifier(10)  # number of classes; 10 in our case
    weights_path = os.path.join(model_dir, 'model.pth')
    with open(weights_path, 'rb') as f:
        # map_location lets weights saved on a GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted artifacts.
        state_dict = torch.load(f, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    # Modules are created in training mode; switch to eval mode so layers
    # such as dropout / batch-norm (if the model has any) behave
    # deterministically at prediction time.
    model.eval()
    return model