def eval_wordstat(opt)

in parlai/scripts/eval_wordstat.py

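The function below depends on these imports (module paths as at the top of
parlai/scripts/eval_wordstat.py in recent ParlAI releases; get_word_stats is a
helper defined in the same file):

import copy
import random
from collections import Counter

import numpy

from parlai.core.agents import create_agent
from parlai.core.dict import DictionaryAgent
from parlai.core.metrics import normalize_answer
from parlai.core.worlds import create_task
from parlai.utils.io import PathManager
from parlai.utils.misc import TimeLogger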

def eval_wordstat(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    agent.opt.log()

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

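    # A non-positive logging interval disables the periodic progress log.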
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
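    # Accumulators gathered over every prediction: word/character lengths,
    # frequency-bin counts, raw predictions, contexts, and the vocabulary used.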
    word_statistics = {
        'mean_wlength': [],
        'mean_clength': [],
        'freqs_cnt': Counter(),
        'word_cnt': 0,
        'pred_list': [],
        'pure_pred_list': [],
        'context_list': [],
        'unique_words': set(),
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

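    # Fold a single prediction into the running statistics; the dict is
    # mutated in place and returned for convenience.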
    def process_prediction(prediction, word_statistics):
        normalized = normalize_answer(prediction)
        word_statistics['pred_list'].append(normalized)
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        word_statistics['unique_words'] |= set(normalized.split(" "))
        return word_statistics

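    # Main evaluation loop: one parley per step; with batching, gather the
    # prediction from each sub-world individually.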
    while not world.epoch_done():
        world.parley()
        if batch_size == 1:
            cnt += 1
            prediction = world.acts[-1]['text']
            word_statistics['context_list'].append(world.acts[0]['text'])
            word_statistics['pure_pred_list'].append(prediction)
            word_statistics = process_prediction(prediction, word_statistics)
        else:
            for w in world.worlds:
                try:
                    if 'text' not in w.acts[-1]:
                        continue
                    prediction = w.acts[-1]['text']
                    word_statistics['context_list'].append(w.acts[0]['text'])
                    word_statistics['pure_pred_list'].append(prediction)
                except IndexError:
                    continue
                cnt += 1
                word_statistics = process_prediction(prediction, word_statistics)

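        # Periodic progress log: evaluation report plus the frequency-bin and
        # length statistics accumulated so far.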
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(
                report['exs'], min(max_cnt, world.num_examples()), report
            )
            print(text)
            stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
            stat_str += ', '.join(
                [
                    '<{}:{} ({:.{prec}f}%)'.format(
                        b,
                        word_statistics['freqs_cnt'].get(b, 0),
                        (
                            word_statistics['freqs_cnt'].get(b, 0)
                            / word_statistics['word_cnt']
                        )
                        * 100,
                        prec=2,
                    )
                    for b in bins
                ]
            )
            print(
                "Word statistics: {}, avg_word_length: {:.{prec}f}, "
                "avg_char_length: {:.{prec}f}".format(
                    stat_str,
                    numpy.array(word_statistics['mean_wlength']).mean(),
                    numpy.array(word_statistics['mean_clength']).mean(),
                    prec=2,
                )
            )
        if cnt >= max_cnt:
            break
    if world.epoch_done():
        print("EPOCH DONE")

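    # A response counts as unique if it occurs exactly once across all
    # predictions; the percentage is a rough proxy for generation diversity.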
    if opt['compute_unique']:
        cntr = Counter(word_statistics['pred_list'])
        unique_list = [k for k, v in cntr.items() if v == 1]
        print(
            "Unique responses: {:.{prec}f}%".format(
                len(unique_list) / len(word_statistics['pred_list']) * 100, prec=2
            )
        )
    print("Total unique tokens:", len(word_statistics['unique_words']))

    if opt['dump_predictions_path'] is not None:
        with PathManager.open(opt['dump_predictions_path'], 'w') as f:
            f.writelines(
                [
                    'CONTEXT: {}\nPREDICTION: {}\n\n'.format(c, p)
                    for c, p in zip(
                        word_statistics['context_list'],
                        word_statistics['pure_pred_list'],
                    )
                ]
            )
        if opt['compute_unique']:
            with PathManager.open(opt['dump_predictions_path'] + '_unique', 'w') as f:
                f.writelines(['{}\n'.format(i) for i in unique_list])

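    # Final frequency-bin and length summary over the full run.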
    stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
    stat_str += ', '.join(
        [
            '<{}:{} ({:.{prec}f}%)'.format(
                b,
                word_statistics['freqs_cnt'].get(b, 0),
                (word_statistics['freqs_cnt'].get(b, 0) / word_statistics['word_cnt'])
                * 100,
                prec=2,
            )
            for b in bins
        ]
    )
    print(
        "Word statistics: {}, avg_word_length: {:.{prec}f}, "
        "avg_char_length: {:.{prec}f}".format(
            stat_str,
            numpy.array(word_statistics['mean_wlength']).mean(),
            numpy.array(word_statistics['mean_clength']).mean(),
            prec=2,
        )
    )

    report = world.report()
    print(report)
    return report
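
A typical invocation, assuming the script is run through its ParlAI wrapper
(flag names are inferred from the opt keys read above and the script's
setup_args; spellings and defaults may differ across ParlAI versions):

    python -m parlai.scripts.eval_wordstat \
        --model-file /path/to/model --task convai2 \
        --freq-bins 0,100,1000,10000 \
        --compute-unique True \
        --dump-predictions-path /tmp/predictions.txt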