in longform-qa/lfqa_utils.py [0:0]
def evaluate_qa_retriever(model, dataset, tokenizer, args, device="cuda:0"):
    """Return the mean per-batch loss of a QA retriever over ``dataset``.

    The dataset is traversed once, in order, with gradient tracking disabled.
    Each batch is built by ``make_qa_retriever_batch`` and unpacked into
    question ids/mask and answer ids/mask, which are passed to ``model``; the
    value the model returns is treated as the batch loss.

    Args:
        model: Callable module invoked as ``model(q_ids, q_mask, a_ids, a_mask)``
            whose result exposes ``.item()``. It is switched to eval mode and
            left in that mode when this function returns.
        dataset: Dataset to evaluate; must be usable with ``SequentialSampler``.
        tokenizer: Tokenizer forwarded to ``make_qa_retriever_batch``.
        args: Object providing ``max_length`` and ``batch_size``.
        device: Device string forwarded to ``make_qa_retriever_batch``.
            Defaults to ``"cuda:0"``, the value that used to be hard-coded.

    Returns:
        The sum of the batch losses divided by the number of batches. This is
        an unweighted mean over batches, so a smaller final batch counts as
        much as a full one.

    Raises:
        ValueError: If the data loader yields no batches (e.g. empty dataset).
    """
    model.eval()
    collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device=device
    )
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=SequentialSampler(dataset),
        collate_fn=collate_fn,
    )
    # The progress bar is deliberately disabled; the wrapper only keeps the
    # iteration shape the same as the training loops.
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    # Count batches explicitly instead of reusing the loop index after the
    # loop: that index is never bound when the loader is empty.
    num_batches = 0
    with torch.no_grad():
        for q_ids, q_mask, a_ids, a_mask in epoch_iterator:
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("evaluate_qa_retriever: dataset produced no batches to evaluate")
    return tot_loss / num_batches