def test_refactor()

in opacus_lab/models/GPT2/refactor.py [0:0]


def test_refactor(
    pretrained,
    refactored,
    text="this is a test",
    tokenizer=None,
    atol=0.0,
    rtol=0.0,
):
    """Check that a refactored model reproduces the pretrained model's logits.

    Both models are put into eval mode in place (via ``.eval()``). The same
    tokenized ``text`` is fed to each, and the refactored model's output is
    compared against ``pretrained(...).logits``.

    Args:
        pretrained: Model whose forward output exposes a ``.logits`` tensor.
        refactored: Model whose forward output is itself the logits tensor.
        text: Prompt to tokenize and run through both models.
        tokenizer: Object with an ``encode(str)`` method returning token ids.
            Defaults to ``GPT2Tokenizer.from_pretrained("gpt2")``. Pass one in
            to avoid reloading it on every call.
        atol: Absolute tolerance for the comparison.
        rtol: Relative tolerance for the comparison. When both tolerances are
            0 (the default), the check is exact via ``torch.equal``.
            Otherwise ``torch.allclose`` is used after checking that the
            shapes match.

    Returns:
        ``True`` if the two outputs match under the chosen tolerance.
    """
    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # 1-D tensor of token ids; no batch dimension is added here.
    input_ids = torch.tensor(tokenizer.encode(text))
    pretrained = pretrained.eval()
    refactored = refactored.eval()
    # Inference-only comparison: don't build an autograd graph for either pass.
    with torch.no_grad():
        expected = pretrained(input_ids).logits
        actual = refactored(input_ids)
    if atol == 0 and rtol == 0:
        return actual.equal(expected)
    # torch.allclose broadcasts, so reject shape mismatches explicitly to keep
    # the "different shape -> not equal" semantics of torch.equal.
    if actual.shape != expected.shape:
        return False
    return torch.allclose(actual, expected, rtol=rtol, atol=atol)


# Despite its name this is a helper, not a pytest test: keep pytest from
# collecting it (and failing on missing ``pretrained``/``refactored``
# fixtures) when a test module imports it.
test_refactor.__test__ = False