def select_unlikelihood_hypos_lm_score()

in preprocess/evaluate_hypo.py


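This excerpt depends on the module-level imports of preprocess/evaluate_hypo.py; a minimal sketch of what the function relies on is shown below (the helper _ans_lm_score is assumed to be defined elsewhere in the same file):

import glob
import json
import os

import numpy as np
from tqdm import tqdm
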
def select_unlikelihood_hypos_lm_score(args):
    """For every line of QA-LM-scored hypotheses under args.base_dir/args.sub_dir, pick,
    per metric, the hypothesis index with the lowest (or, with args.select_highest, the
    highest) average score, and write the selections to an 'untarget.index' file."""
    assert os.path.exists(os.path.join(args.base_dir, args.sub_dir))

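    # For one list of candidate hypotheses, return, per metric, the index of the
    # hypothesis whose average per-QA score is lowest (or highest, when
    # args.select_highest is set), along with the selected averages; the min-based
    # return values are currently placeholders.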
    def _compute_lm_for_hypos(hypo_list, metrics):
        return_hypo_ind_avg = {}
        return_hypo_ind_min = {}
        return_hypo_avg = {}
        return_hypo_min = {}
        # Besides the metrics passed in, track 'eval_prob' (the averaged exp of the
        # per-QA 'eval_ns' score) and 'ans_lm' (the averaged answer-span LM score).
        # +/-1e5 act as "not yet set" sentinels for the running best averages.
        prob_metric = 'eval_prob'
        return_hypo_ind_avg[prob_metric] = 0
        return_hypo_avg[prob_metric] = -1e5 if args.select_highest else 1e5

        ans_lm_metric = 'ans_lm'
        return_hypo_ind_avg[ans_lm_metric] = 0
        return_hypo_avg[ans_lm_metric] = -1e5 if args.select_highest else 1e5
        # Answer-span LM score: for each QA pair, _ans_lm_score scores the answer tokens
        # from the element-wise difference of 'eval_pos_scores' and 'pos_s'; the per-QA
        # scores are averaged per hypothesis.
        for i, hypo in enumerate(hypo_list):
            sum_ans_lm_per_hypo = 0.0
            count_per_hypo = 0
            for qa in hypo['qas']:
                ans_found = False
                if 'pos_s' in qa:
                    ans_found, value = _ans_lm_score(
                        np.array(qa['eval_pos_scores']) - np.array(qa['pos_s']), qa['toks'])
                if ans_found:
                    sum_ans_lm_per_hypo += value
                    count_per_hypo += 1
            if count_per_hypo > 0:
                avg_per_hypo = sum_ans_lm_per_hypo / count_per_hypo
                if not args.select_highest and avg_per_hypo < return_hypo_avg[ans_lm_metric]:
                    return_hypo_ind_avg[ans_lm_metric] = i
                    return_hypo_avg[ans_lm_metric] = avg_per_hypo
                if args.select_highest and avg_per_hypo > return_hypo_avg[ans_lm_metric]:
                    return_hypo_ind_avg[ans_lm_metric] = i
                    return_hypo_avg[ans_lm_metric] = avg_per_hypo

        # For each requested metric, pick the hypothesis with the best (lowest or highest)
        # average per-QA score. Min-based selection is currently disabled, so
        # return_hypo_min stays empty and return_hypo_ind_min stays at 0.
        for metric in metrics:
            return_hypo_ind_avg[metric] = 0
            return_hypo_ind_min[metric] = 0
            return_hypo_avg[metric] = -1e5 if args.select_highest else 1e5
            for i, hypo in enumerate(hypo_list):
                sum_prob_per_hypo = 0.0
                sum_per_hypo = 0.0
                count_per_hypo = 0
                min_per_hypo = 1e5
                for qa in hypo['qas']:
                    metric_split = metric.split('-')
                    if len(metric_split) == 1:
                        value = qa[metric]
                    else:
                        value = qa[metric_split[0]] - qa[metric_split[1]]
                    if metric == 'eval_ns':
                        # Also accumulate exp of the 'eval_ns' value; its per-hypothesis
                        # average drives the separate 'eval_prob' selection below.
                        sum_prob_per_hypo += np.exp(value)
                    sum_per_hypo += value
                    if min_per_hypo > value:
                        min_per_hypo = value
                    count_per_hypo += 1
                if count_per_hypo > 0:
                    if metric == 'eval_ns':
                        avg_prob_per_hypo = sum_prob_per_hypo / count_per_hypo
                        if not args.select_highest and avg_prob_per_hypo < return_hypo_avg[prob_metric]:
                            return_hypo_avg[prob_metric] = avg_prob_per_hypo
                            return_hypo_ind_avg[prob_metric] = i
                        if args.select_highest and avg_prob_per_hypo > return_hypo_avg[prob_metric]:
                            return_hypo_avg[prob_metric] = avg_prob_per_hypo
                            return_hypo_ind_avg[prob_metric] = i

                    avg_per_hypo = sum_per_hypo / count_per_hypo
                    if not args.select_highest and avg_per_hypo < return_hypo_avg[metric]:
                        return_hypo_ind_avg[metric] = i
                        return_hypo_avg[metric] = avg_per_hypo
                    if args.select_highest and avg_per_hypo > return_hypo_avg[metric]:
                        return_hypo_ind_avg[metric] = i
                        return_hypo_avg[metric] = avg_per_hypo
        return return_hypo_ind_avg, return_hypo_ind_min, return_hypo_avg, return_hypo_min

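    # Collect the QA-LM score files and process them line by line; each line is a
    # JSON-encoded list of candidate hypotheses.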
    lm_files = sorted(glob.glob(os.path.join(args.base_dir, args.sub_dir, args.pattern)))
    metrics = ['eval_ns', 'eval_a_ns', 'eval_uns', 'eval_a_uns',
               'eval_ns-ns', 'eval_a_ns-ns', 'eval_uns-uns', 'eval_a_uns-uns']
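    # Plain metric names are read directly from each QA dict; names of the form 'a-b'
    # are scored per QA pair as qa['a'] - qa['b'] inside _compute_lm_for_hypos.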

    # One selection record per input line goes to 'untarget.index'; running sums are
    # kept so the averages of the selected values can be reported at the end.
    output_index_file = os.path.join(args.base_dir, args.sub_dir, 'untarget.index')
    avg_avg_scores = {metric: 0.0 for metric in metrics}
    avg_avg_scores['eval_prob'] = 0.0
    avg_avg_scores['ans_lm'] = 0.0
    count = 0
    with open(output_index_file, 'w') as output_index_f:
        for lm_file in tqdm(lm_files):
            with open(lm_file, 'r') as lm_f:
                for line in lm_f:
                    hypo_list = json.loads(line.strip())
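                    # Each hypothesis dict is expected to carry a 'qas' list whose entries
                    # hold the per-metric scores read below, plus 'toks', 'pos_s', and
                    # 'eval_pos_scores' for the answer-span LM score.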
                    bad_hypo_ind_avg, bad_hypo_ind_min, min_avg_value, min_min_value = \
                        _compute_lm_for_hypos(hypo_list, metrics)
                    # Only lines where every tracked metric was actually updated (i.e. no
                    # +/-1e5 sentinel is left) contribute to the reported averages.
                    if all(v != 1e5 for v in min_avg_value.values()) and \
                            all(v != -1e5 for v in min_avg_value.values()):
                        for metric in metrics:
                            avg_avg_scores[metric] += min_avg_value[metric]
                        avg_avg_scores['eval_prob'] += min_avg_value['eval_prob']
                        avg_avg_scores['ans_lm'] += min_avg_value['ans_lm']
                        count += 1
                    # One JSON record per line: the selected hypothesis index and value
                    # for every metric (the 'min' entries are placeholders, see above).
                    json.dump({'avg': bad_hypo_ind_avg, 'min': bad_hypo_ind_min,
                               'avg_value': min_avg_value, 'min_value': min_min_value}, output_index_f)
                    output_index_f.write('\n')
    # Average the accumulated selected values and print them in one line
    # (per-metric averages first, then 'eval_prob' and 'ans_lm').
    avg_out_text = ""
    print("count = {}".format(count))
    for metric in metrics:
        if count > 0:
            avg_avg_scores[metric] /= count
        avg_out_text += f"{avg_avg_scores[metric]} "
    avg_out_text += f"{avg_avg_scores['eval_prob'] / count if count > 0 else 0.0} "
    avg_out_text += f"{avg_avg_scores['ans_lm'] / count if count > 0 else 0.0}"
    print("select_unlikelihood_hypos_lm_score Done! Index written to {}".format(output_index_file))
    print("avg lm scores: " + avg_out_text)
    if count == 0:
        print("Avg lm scores were not computed, probably because return_token_scores was set to False when running qa_gen in preprocess/sm_inference_asum.py")
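
A minimal invocation sketch; the attribute names on args are the ones the function reads, while the values below (directory layout, filename pattern) are purely illustrative assumptions:

import argparse

# Hypothetical values; adjust to the actual layout produced by the preprocessing steps.
args = argparse.Namespace(
    base_dir='processed_data',       # assumed root directory
    sub_dir='cand_hypos_lm_score',   # assumed sub-directory with the scored hypothesis files
    pattern='*.jsonl',               # assumed filename pattern matched by glob
    select_highest=False,            # select the lowest-scoring hypothesis per metric
)
select_unlikelihood_hypos_lm_score(args)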