def save_model()

in pytorch_alternatives/custom_pytorch_nlp/src/main.py [0:0]


def save_model(model, model_dir: str, max_seq_len: int) -> None:
    """Trace *model* with TorchScript and save it as ``model.pth`` in *model_dir*.

    The model is moved to CPU and put in eval mode before being traced with a
    dummy integer input of shape ``(1, max_seq_len)``. The traced module is
    then serialized with ``torch.jit.save``.

    Args:
        model: The model to export. It must provide ``.cpu()`` and ``.eval()``
            and be traceable by ``torch.jit.trace`` with a single integer
            tensor argument.
        model_dir: Directory that receives ``model.pth``. It is created
            (including parents) if it does not exist yet.
        max_seq_len: Length of the second dimension of the dummy input used
            for tracing.

    Note:
        For a ``torch.nn.Module``, ``.cpu()`` and ``.eval()`` act on the
        object passed in, so the caller's model is presumably left on CPU and
        in eval mode after this call -- confirm if training continues
        afterwards.
    """
    # torch.jit.save does not create missing directories; make sure the
    # destination exists so the export does not fail at the very last step.
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "model.pth")

    # Dummy batch of one sequence with integer ids drawn from [0, 10).
    # NOTE(review): assumes ids below 10 are valid inputs for the model
    # (e.g. an embedding with at least 10 rows) -- confirm against the model.
    example_input = torch.randint(0, 10, (1, max_seq_len))

    model = model.cpu()
    # Eval mode before tracing so the trace records inference-time behavior.
    model.eval()

    traced = torch.jit.trace(model, example_input)
    torch.jit.save(traced, path)