in longform-qa/lfqa_utils.py [0:0]
def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
    """Train the QA retriever for ``qar_args.num_epochs`` epochs.

    Builds an AdamW optimizer over the model parameters and a linear
    warmup schedule, then for every epoch:

    1. runs one training pass via ``train_qa_retriever_epoch``,
    2. writes the model, optimizer and scheduler state dicts to
       ``"<model_save_name>_<epoch>.pth"``,
    3. evaluates on the validation set via ``evaluate_qa_retriever`` and
       prints the resulting loss.

    Args:
        qar_model: Model to train; must expose ``parameters()`` and
            ``state_dict()``.
        qar_tokenizer: Tokenizer forwarded to the train/eval helpers.
        qar_train_dset: Training dataset; its ``len()`` sizes the schedule.
        qar_valid_dset: Validation dataset used after every epoch.
        qar_args: Settings object read for ``learning_rate``,
            ``num_epochs``, ``batch_size`` and ``model_save_name``.

    Returns:
        None. Results are reported through ``print`` and checkpoint files.
    """
    optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)

    # The schedule is sized for one epoch more than is actually run.
    # NOTE(review): presumably this keeps the learning rate from decaying
    # all the way to zero during the final epoch -- confirm intent.
    scheduled_epochs = qar_args.num_epochs + 1
    batches_per_epoch = math.ceil(len(qar_train_dset) / qar_args.batch_size)
    total_steps = scheduled_epochs * batches_per_epoch
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=total_steps,
    )

    for epoch in range(qar_args.num_epochs):
        train_qa_retriever_epoch(
            qar_model,
            qar_train_dset,
            qar_tokenizer,
            optimizer,
            scheduler,
            qar_args,
            epoch,
        )

        # Snapshot everything needed to resume training from this epoch.
        checkpoint = {
            "model": qar_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        print("Saving model {}".format(qar_args.model_save_name))
        # One file per epoch; earlier checkpoints are not overwritten.
        checkpoint_path = "{}_{}.pth".format(qar_args.model_save_name, epoch)
        torch.save(checkpoint, checkpoint_path)

        # Evaluation runs after the checkpoint has been written.
        valid_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
        print("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, valid_loss))