in parlai/scripts/eval_wordstat.py [0:0]
def eval_wordstat(opt):
    """
    Evaluate a model and collect word-level statistics over its predictions.

    Runs the model over the task until the epoch is done (or ``num_examples``
    is reached), tracking mean word/char lengths, word-frequency-bin counts,
    and — optionally — response uniqueness. Predictions (and unique
    predictions) may be dumped to disk via ``dump_predictions_path``.

    :param opt: tells the evaluation function how to run
    :return: the final world report dict
    """
    random.seed(42)

    # Create model and assign it to the specified task.
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    agent.opt.log()

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        # Non-positive means "never log periodically".
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    word_statistics = {
        'mean_wlength': [],  # per-prediction mean word length
        'mean_clength': [],  # per-prediction mean char length
        'freqs_cnt': Counter(),  # word counts per frequency bin
        'word_cnt': 0,  # total words across all predictions
        'pred_list': [],  # normalized predictions (for uniqueness stats)
        'pure_pred_list': [],  # raw predictions (for dumping)
        'context_list': [],  # contexts paired with the predictions
        'unique_words': set(),  # distinct normalized tokens seen so far
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    def process_prediction(prediction, word_statistics):
        """Accumulate word statistics for a single model prediction."""
        normalized = normalize_answer(prediction)
        word_statistics['pred_list'].append(normalized)
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        word_statistics['unique_words'] |= set(normalized.split(" "))
        return word_statistics

    def print_word_stats():
        """Print the frequency-bin breakdown and average word/char lengths."""
        # Guard the denominator: no words counted yet must not crash the
        # periodic log (the numerator is 0 in that case anyway).
        total_words = word_statistics['word_cnt'] or 1
        stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
        stat_str += ', '.join(
            '<{}:{} ({:.2f}%)'.format(
                b,
                word_statistics['freqs_cnt'].get(b, 0),
                word_statistics['freqs_cnt'].get(b, 0) / total_words * 100,
            )
            for b in bins
        )
        print(
            "Word statistics: {}, avg_word_length: {:.2f}, "
            "avg_char_length: {:.2f}".format(
                stat_str,
                numpy.array(word_statistics['mean_wlength']).mean(),
                numpy.array(word_statistics['mean_clength']).mean(),
            )
        )

    while not world.epoch_done():
        world.parley()
        if batch_size == 1:
            cnt += 1
            prediction = world.acts[-1]['text']
            word_statistics['context_list'].append(world.acts[0]['text'])
            word_statistics['pure_pred_list'].append(prediction)
            word_statistics = process_prediction(prediction, word_statistics)
        else:
            for w in world.worlds:
                try:
                    if 'text' not in w.acts[-1]:
                        continue
                    prediction = w.acts[-1]['text']
                    word_statistics['context_list'].append(w.acts[0]['text'])
                    word_statistics['pure_pred_list'].append(prediction)
                except IndexError:
                    # A sub-world may not have produced any acts yet.
                    continue
                cnt += 1
                word_statistics = process_prediction(prediction, word_statistics)
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(
                report['exs'], min(max_cnt, world.num_examples()), report
            )
            print(text)
            print_word_stats()
        if cnt >= max_cnt:
            break

    if world.epoch_done():
        print("EPOCH DONE")

    if opt['compute_unique'] is True:
        # Responses that appeared exactly once across the evaluation.
        unique_list = [
            k for k, v in Counter(word_statistics['pred_list']).items() if v == 1
        ]
        # Guard against division by zero when no predictions were collected.
        total_preds = len(word_statistics['pred_list']) or 1
        print(
            "Unique responses: {:.2f}%".format(len(unique_list) / total_preds * 100)
        )
        print("Total unique tokens:", len(word_statistics['unique_words']))

    if opt['dump_predictions_path'] is not None:
        with PathManager.open(opt['dump_predictions_path'], 'w') as f:
            f.writelines(
                [
                    'CONTEXT: {}\nPREDICTION:{}\n\n'.format(c, p)
                    for c, p in zip(
                        word_statistics['context_list'],
                        word_statistics['pure_pred_list'],
                    )
                ]
            )
        if opt['compute_unique'] is True:
            with PathManager.open(opt['dump_predictions_path'] + '_unique', 'w') as f:
                f.writelines(['{}\n'.format(i) for i in unique_list])

    print_word_stats()
    report = world.report()
    print(report)
    return report