in dpr_scale/eval_dpr.py [0:0]
import json

import numpy as np
from tqdm import tqdm

# SimpleTokenizer and has_answers are assumed to be imported from the repo's
# utility modules; those imports are elided in this excerpt.


def evaluate_retrieval(retrieval_file, topk, regex=False, oufname=''):
    tokenizer = SimpleTokenizer()
    retrieval = json.load(open(retrieval_file))
    accuracy = {k: [] for k in topk}
    max_k = max(topk)
    for question in tqdm(retrieval):
        answers = question['answers']
        contexts = question['ctxs']
        has_ans_idx = max_k  # first index in contexts that has answers
        for idx, ctx in enumerate(contexts):
            if idx >= max_k:
                break
            text = ctx['text']  # .split('\n')[1] # [0] is title, [1] is text
            if has_answers(text, answers, tokenizer, regex):
                has_ans_idx = min(has_ans_idx, idx)
                if oufname:
                    ctx['has_answer'] = True
                else:
                    # when not writing out eval results, stop at the first hit;
                    # otherwise keep scanning so every ctx gets a has_answer flag
                    break
            elif oufname:
                ctx['has_answer'] = False
        # a question counts as correct at k if an answer appeared before rank k
        for k in topk:
            accuracy[k].append(0 if has_ans_idx >= k else 1)
    print("Evaluating", retrieval_file)
    for k in topk:
        print(f'Top{k}\taccuracy: {np.mean(accuracy[k])}')
    if oufname:
        with open(oufname, 'w') as ouf:
            json.dump(retrieval, ouf, indent=4)
    return accuracy
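
For context, a minimal driver for this function might look like the sketch below; the file names and top-k cutoffs are placeholders chosen for illustration, not values taken from the repo.

# Hypothetical usage sketch: file names and cutoffs below are placeholders.
if __name__ == "__main__":
    # The retrieval file is expected to be a JSON list of records shaped like
    # {"answers": [...], "ctxs": [{"text": ...}, ...]}, matching what
    # evaluate_retrieval reads above.
    acc = evaluate_retrieval(
        retrieval_file="retrieval.json",     # retriever output (placeholder path)
        topk=[1, 5, 20, 100],                # report Top-1/5/20/100 accuracy
        regex=False,                         # treat answers as literal strings, not regex patterns
        oufname="retrieval_annotated.json",  # also write per-context has_answer flags
    )
    # acc maps each k to a 0/1 list over questions; np.mean(acc[k]) is the Top-k accuracy.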