def make_qa_s2s_model()

in longform-qa/lfqa_utils.py [0:0]


def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    """Build a seq2seq QA model and its tokenizer, optionally restoring weights.

    Args:
        model_name: Identifier or path passed to ``from_pretrained`` for both
            the tokenizer and the model.
        from_file: Optional path to a checkpoint readable by ``torch.load``.
            The checkpoint must be a mapping with a ``"model"`` entry holding
            the model state dict; any other entries are ignored here.
        device: Device the returned model is moved to.

    Returns:
        A ``(tokenizer, model)`` tuple, with ``model`` placed on ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if from_file is not None:
        # Deserialize onto the CPU: without map_location, torch.load restores
        # every tensor on the device it was saved from, which fails when that
        # device is unavailable and otherwise puts the whole checkpoint
        # (reportedly also optimizer and scheduler states) on the GPU even
        # though only the model weights are used.
        # NOTE: torch.load is pickle-based; only load checkpoints you trust.
        param_dict = torch.load(from_file, map_location="cpu")
        model.load_state_dict(param_dict["model"])
    # Move to the target device last so only the final weights are transferred.
    model = model.to(device)
    return tokenizer, model