in preprocess/evaluate_hypo.py [0:0]
def filter_qas_dataset_lm_score(args):
    """Filter generated question/answer pairs and write a ``*_filtered*`` copy of each file.

    Files matching ``os.path.join(args.base_dir, args.sub_dir, args.pattern)``
    are processed in sorted order. Each input line is a JSON list of dicts with
    keys 'context', 'qa', 'norm_scores', 'unnorm_scores' and, optionally,
    'pos_scores' and 'toks'. Every 'qa' string is a question and an answer
    joined by ' strutConnector'.

    For each dict, only QA pairs whose answer occurs (case-insensitively) in
    'context' are considered. Then:
      * if ``args.filter_ans_lm_score`` is set and 'pos_scores'/'toks' are
        present, pairs whose answer LM score reaches
        ``args.filter_ans_lm_threshold`` are kept;
      * otherwise up to 10 pairs are chosen round-robin across distinct
        answers, in descending order of unnormalized score, skipping
        questions already kept for the same answer.

    One JSON line is written per input line to ``<file>_filtered<threshold>``
    (or ``<file>_filtered`` when answer-LM filtering is disabled). Each output
    entry has the form ``{'context': ..., 'qas': [...]}``.

    Args:
        args: namespace providing base_dir, sub_dir, pattern,
            filter_ans_lm_score and filter_ans_lm_threshold.
    """
    qas_files = sorted(
        glob.glob(os.path.join(args.base_dir, args.sub_dir, args.pattern)))
    for qas_file in tqdm(qas_files):
        if args.filter_ans_lm_score:
            output_file = qas_file + '_filtered{}'.format(args.filter_ans_lm_threshold)
        else:
            output_file = qas_file + '_filtered'
        # Explicit UTF-8 so results do not depend on the platform's locale encoding.
        with open(qas_file, 'r', encoding='utf-8') as qas_f, \
                open(output_file, 'w', encoding='utf-8') as output_f:
            for line in qas_f:
                filtered_qa_dict_list = []
                for qa_dict in json.loads(line.strip()):
                    has_pos = 'pos_scores' in qa_dict and 'toks' in qa_dict
                    candidates = _extract_qa_candidates(qa_dict, has_pos)
                    if args.filter_ans_lm_score and has_pos:
                        qas = _select_qas_by_ans_lm_score(
                            candidates, args.filter_ans_lm_threshold)
                    else:
                        qas = _select_qas_round_robin(candidates)
                    filtered_qa_dict_list.append({'context': qa_dict['context'], 'qas': qas})
                # One output line per input line keeps the files aligned.
                json.dump(filtered_qa_dict_list, output_f)
                output_f.write('\n')


def _extract_qa_candidates(qa_dict, has_pos):
    """Return the (question, answer, scores...) tuples whose answer occurs in the context.

    A 'qa' string is kept only if it splits into exactly two parts on
    ' strutConnector' and the answer part is a case-insensitive substring of
    ``qa_dict['context']``. Tuples are ``(q, a, norm_score, unnorm_score)``,
    extended with ``(pos_score, tokens)`` when *has_pos* is true.
    """
    hypo_text_lower = qa_dict['context'].lower()
    columns = [qa_dict['qa'], qa_dict['norm_scores'], qa_dict['unnorm_scores']]
    if has_pos:
        columns += [qa_dict['pos_scores'], qa_dict['toks']]
    candidates = []
    for qa, *scores in zip(*columns):
        q_a_split = qa.split(' strutConnector')
        if len(q_a_split) == 2 and q_a_split[1].lower() in hypo_text_lower:
            candidates.append((q_a_split[0], q_a_split[1], *scores))
    return candidates


def _select_qas_by_ans_lm_score(candidates, threshold):
    """Keep 6-field candidates whose answer LM score sum is >= *threshold*."""
    selected = []
    for q, a, ns, uns, pos_s, toks in candidates:
        # _ans_lm_score returns a pair; only its second element (presumably a
        # summed answer score) is used here.
        _, ans_score_sum = _ans_lm_score(pos_s, toks)
        if ans_score_sum >= threshold:
            selected.append({'q': q, 'a': a, 'ns': ns, 'uns': uns,
                             'pos_s': pos_s, 'toks': toks})
    return selected


def _qa_candidate_to_dict(cand):
    """Convert a candidate tuple to an output dict; 'pos_s'/'toks' only when a pos score exists."""
    entry = {'q': cand[0], 'a': cand[1], 'ns': cand[2], 'uns': cand[3]}
    if len(cand) == 6 and cand[4] is not None:
        entry['pos_s'] = cand[4]
        entry['toks'] = cand[5]
    return entry


def _select_qas_round_robin(candidates, max_qas=10):
    """Pick at most *max_qas* QAs, cycling over distinct answers to favor diversity.

    Candidates are ranked by descending unnormalized score (stable), grouped by
    lower-cased answer in order of first appearance, and taken one per answer
    per round. A candidate whose lower-cased question was already kept for the
    same answer is skipped, but its answer group stays active, so later
    candidates of that group are still considered.
    """
    ranked = sorted(candidates, key=lambda c: -c[3])
    groups = OrderedDict()
    for cand in ranked:
        groups.setdefault(cand[1].lower(), []).append(cand)

    # One (remaining-candidates iterator, kept-questions set) pair per answer.
    queues = [(iter(group), set()) for group in groups.values()]
    selected = []
    while queues and len(selected) < max_qas:
        still_active = []
        for cands, kept_questions in queues:
            if len(selected) >= max_qas:
                break  # enforce the cap within a round, not just between rounds
            cand = next(cands, None)
            if cand is None:
                continue  # this answer has no candidates left
            still_active.append((cands, kept_questions))
            question_key = cand[0].lower()
            if question_key in kept_questions:
                continue  # same question already kept for this answer
            kept_questions.add(question_key)
            selected.append(_qa_candidate_to_dict(cand))
        queues = still_active
    return selected