in longform-qa/lfqa_utils.py [0:0]
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    # Load the tokenizer and seq2seq model from the Hugging Face hub and move the model to the target device.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        # The checkpoint holds model weights along with optimizer and scheduler states;
        # only the model weights are restored here.
        param_dict = torch.load(from_file)
        model.load_state_dict(param_dict["model"])
    return tokenizer, model
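
A minimal usage sketch (not part of the original file): the checkpoint path "qa_s2s_checkpoint.pth" is hypothetical, and passing it assumes the checkpoint was saved as a dict with a "model" key, as the function expects. Omit from_file to use the pretrained BART weights as-is.

    # Load the pretrained model, optionally restoring fine-tuned weights from a checkpoint.
    qa_tokenizer, qa_model = make_qa_s2s_model(
        model_name="facebook/bart-large",
        from_file="qa_s2s_checkpoint.pth",  # hypothetical path to a saved training checkpoint
        device="cuda:0",
    )
    qa_model.eval()  # switch to inference mode before generating answers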