def model_fn()

in code/deploy_ei.py [0:0]


def model_fn(model_dir):
    """Load the TorchScript archive ``traced_bert.pt`` found in ``model_dir``.

    The target device is CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise the CPU. The contents of ``model_dir`` and a completion
    banner are printed to stdout as a deployment diagnostic.

    NOTE(review): the name and signature look like a model-server loading
    hook -- confirm against the serving framework that calls this.

    Args:
        model_dir: Directory that contains ``traced_bert.pt``.

    Returns:
        The loaded TorchScript module, placed on the selected device.

    Raises:
        Nothing is caught here: errors from ``os.listdir`` (e.g. a missing
        directory) or ``torch.jit.load`` propagate to the caller.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("================== objects in model_dir =====================")
    print(os.listdir(model_dir))

    # map_location remaps the saved tensors onto the chosen device while
    # loading. Without it, an archive saved from a GPU tries to restore onto
    # that GPU and fails on a CPU-only host before .to(device) can run.
    model_path = os.path.join(model_dir, "traced_bert.pt")
    loaded_model = torch.jit.load(model_path, map_location=device)
    print("================== model loaded =============================")

    # Kept so the returned module is guaranteed to live on `device`, exactly
    # as before; a no-op when map_location already placed it there.
    return loaded_model.to(device)