def train_qa_retriever()

in longform-qa/lfqa_utils.py
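
Top-level training loop for the dense question-answer retriever. Each epoch runs train_qa_retriever_epoch over the training set, saves a checkpoint named "{model_save_name}_{epoch}.pth" holding the model, optimizer, and scheduler state dicts, and prints the evaluation loss computed by evaluate_qa_retriever on the validation set.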


def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
    # AdamW optimizer with a linear warmup / decay schedule over the full training run
    qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
    qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
    )
    for e in range(qar_args.num_epochs):
        # one full pass over the training set
        train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
        # checkpoint model, optimizer, and scheduler state after every epoch
        m_save_dict = {
            "model": qar_model.state_dict(),
            "optimizer": qar_optimizer.state_dict(),
            "scheduler": qar_scheduler.state_dict(),
        }
        print("Saving model {}".format(qar_args.model_save_name))
        torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
        # report the loss on the held-out validation set
        eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
        print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))