grade_school_math/sample.py (17 lines of code) (raw):

import torch as th
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from calculator import sample
from dataset import get_examples, GSMDataset


def main():
    """Print one test question and the model's sampled output for it.

    Loads the stock "gpt2" tokenizer and a GPT-2 LM-head model from the
    local "model_ckpts" directory, takes the question of the test example
    at index 1, and prints the stripped question followed by whatever
    ``calculator.sample`` returns for it.
    """
    # Use the GPU when one is present, otherwise fall back to CPU so the
    # script does not fail while moving the model on a CUDA-less machine.
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("model_ckpts")
    model.to(device)
    print("Model Loaded")

    test_examples = get_examples("test")
    qn = test_examples[1]["question"]
    # Forwarded to sample() as its length argument; presumably an upper
    # bound on the number of sampling steps -- confirm against calculator.py.
    sample_len = 100

    print(qn.strip())
    # The unstripped question is what gets passed to sample(), matching the
    # original behavior; only the echoed copy above is stripped.
    print(sample(model, qn, tokenizer, device, sample_len))


if __name__ == "__main__":
    main()