def filter_qas_dataset_lm_score()

in preprocess/evaluate_hypo.py [0:0]


def _answerable_candidates(qa_dict):
    """Collect the QA candidates of ``qa_dict`` whose answer occurs in its context.

    Each string in ``qa_dict['qa']`` is split on ``' strutConnector'``; it is
    kept only when the split yields exactly two parts (question, answer) and
    the lower-cased answer is a substring of the lower-cased context.

    Returns a list of tuples ``(question, answer, norm_score, unnorm_score)``,
    extended with ``(pos_score, tokens)`` when ``qa_dict`` carries both the
    ``'pos_scores'`` and ``'toks'`` keys.
    """
    context_lower = qa_dict['context'].lower()
    columns = [qa_dict['qa'], qa_dict['norm_scores'], qa_dict['unnorm_scores']]
    if 'pos_scores' in qa_dict and 'toks' in qa_dict:
        columns.extend([qa_dict['pos_scores'], qa_dict['toks']])

    candidates = []
    # zip() stops at the shortest column, exactly like the original loops.
    for row in zip(*columns):
        parts = row[0].split(' strutConnector')
        if len(parts) == 2 and parts[1].lower() in context_lower:
            candidates.append((parts[0], parts[1]) + tuple(row[1:]))
    return candidates


def _qa_record(candidate, with_token_scores):
    """Build the output dict for one candidate tuple.

    ``with_token_scores`` adds the ``'pos_s'`` and ``'toks'`` entries taken
    from positions 4 and 5 of the tuple.
    """
    record = {'q': candidate[0], 'a': candidate[1],
              'ns': candidate[2], 'uns': candidate[3]}
    if with_token_scores:
        record['pos_s'] = candidate[4]
        record['toks'] = candidate[5]
    return record


def _select_by_answer_lm_score(candidates, threshold):
    """Keep the 6-field candidates whose answer LM score reaches ``threshold``.

    ``_ans_lm_score`` is called with the candidate's position scores and
    tokens; only the second element of its result is compared against the
    threshold (presumably the summed LM score of the answer tokens - confirm
    against ``_ans_lm_score``).
    """
    selected = []
    for candidate in candidates:
        _, ans_score_sum = _ans_lm_score(candidate[4], candidate[5])
        if ans_score_sum >= threshold:
            selected.append(_qa_record(candidate, True))
    return selected


def _select_round_robin(candidates, max_qas=10):
    """Pick candidates round-robin across distinct answers, best score first.

    Candidates are ordered by descending ``unnorm_score`` (tuple index 3) and
    grouped by lower-cased answer in first-seen order. Each round takes the
    next candidate from every answer's queue; a candidate whose lower-cased
    question was already selected for that answer is discarded. Rounds go on
    while the previous round selected something and fewer than ``max_qas``
    records have been selected.

    NOTE(review): the cap is only checked between rounds, so the result can
    hold more than ``max_qas`` records when there are many distinct answers.
    NOTE(review): a round in which every popped candidate is a duplicate ends
    the selection even if unseen questions remain queued. Both behaviours are
    preserved from the original implementation; confirm they are intended.
    """
    ranked = sorted(candidates, key=lambda candidate: -candidate[3])

    queues_by_answer = OrderedDict()
    for candidate in ranked:
        queues_by_answer.setdefault(candidate[1].lower(), []).append(candidate)

    selected = []
    questions_by_answer = {}
    keep_adding = True
    while keep_adding and len(selected) < max_qas:
        keep_adding = False
        for queue in queues_by_answer.values():
            if not queue:
                continue
            candidate = queue.pop(0)
            asked = questions_by_answer.setdefault(candidate[1].lower(), set())
            question_key = candidate[0].lower()
            if question_key in asked:
                # Same question already kept for this answer: drop it.
                continue
            asked.add(question_key)
            # A candidate whose pos_score is None is written without the
            # token-level fields, as in the original implementation.
            with_token_scores = len(candidate) == 6 and candidate[4] is not None
            selected.append(_qa_record(candidate, with_token_scores))
            keep_adding = True
    return selected


def _filter_qa_dict(qa_dict, args):
    """Return ``{'context': ..., 'qas': [...]}`` with the filtered QA pairs of one entry."""
    filtered_qa_dict = {'context': qa_dict['context'], 'qas': []}
    candidates = _answerable_candidates(qa_dict)
    if not candidates:
        return filtered_qa_dict

    has_token_scores = 'pos_scores' in qa_dict and 'toks' in qa_dict
    if args.filter_ans_lm_score and has_token_scores:
        # Filter QA pairs using the LM score of the answer.
        filtered_qa_dict['qas'] = _select_by_answer_lm_score(
            candidates, args.filter_ans_lm_threshold)
    else:
        filtered_qa_dict['qas'] = _select_round_robin(candidates)
    return filtered_qa_dict


def filter_qas_dataset_lm_score(args):
    """Filter generated question-answer pairs in every matching JSON-lines file.

    Input files are found with ``glob`` on
    ``os.path.join(args.base_dir, args.sub_dir, args.pattern)`` and processed
    in sorted order. Every input line must be a JSON list of dicts holding
    ``'context'``, ``'qa'``, ``'norm_scores'``, ``'unnorm_scores'`` and,
    optionally, ``'pos_scores'`` and ``'toks'``. For each line, one JSON list
    of ``{'context': ..., 'qas': [...]}`` dicts is written to the output file.

    The output path is the input path plus ``'_filtered'``, followed by
    ``args.filter_ans_lm_threshold`` when ``args.filter_ans_lm_score`` is set.
    QA pairs whose answer does not occur in the context are always dropped.
    The remaining pairs are filtered by answer LM score when
    ``args.filter_ans_lm_score`` is set and token-level scores are present;
    otherwise they are selected round-robin across distinct answers.

    NOTE(review): outputs are written next to the inputs, so a pattern that
    also matches ``*_filtered*`` picks up earlier outputs on a rerun.

    Args:
        args: namespace providing ``base_dir``, ``sub_dir``, ``pattern``,
            ``filter_ans_lm_score`` and ``filter_ans_lm_threshold``.
    """
    qas_files = sorted(
        glob.glob(os.path.join(args.base_dir, args.sub_dir, args.pattern)))
    if args.filter_ans_lm_score:
        suffix = '_filtered{}'.format(args.filter_ans_lm_threshold)
    else:
        suffix = '_filtered'

    for qas_file in tqdm(qas_files):
        output_file = qas_file + suffix
        # JSON is UTF-8 by convention; do not depend on the locale default.
        with open(qas_file, 'r', encoding='utf-8') as qas_f, \
             open(output_file, 'w', encoding='utf-8') as output_f:
            for line in qas_f:
                qa_dict_list = json.loads(line.strip())
                filtered_qa_dict_list = [_filter_qa_dict(qa_dict, args)
                                         for qa_dict in qa_dict_list]
                json.dump(filtered_qa_dict_list, output_f)
                output_f.write('\n')