in retrieval_eval_bleu.py
def get_bleu4(split, history_len=1):
    """
    Print BLEU scores, and write contexts and retrieved responses to a file.
    """
    if history_len < 1:
        history_len = 1
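    # Tally of where each selected response came from,
    # indexed as [Reddit, EmpChat, DailyDialog]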
    source_ct = [0, 0, 0]
    net_parlai_dict = ParlAIDictionary.create_from_reddit_style(net_dictionary)
    bleu_parlai_dict = ParlAIDictionary.create_from_reddit_style(bleu_dictionary)
    scorer = bleu.Scorer(BLEU_PAD_IDX, BLEU_EOS_IDX, BLEU_UNK_IDX)
    outf = open(f"retrieved_split_{args.name}_{split}.txt", "w")
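
    # Helper: build the requested dataset from a (reddit-style dict, ParlAI
    # dict) pair; only the Reddit dataset consumes the reddit-style dictionary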
    def _get_dataset(reddit_dict, parlai_dict):
        if args.task == "dailydialog":
            return DDDataset(
                split,
                parlai_dict,
                data_folder=args.dailydialog_folder,
                history_len=history_len,
            )
        elif args.task == "empchat":
            return EmpDataset(
                split,
                parlai_dict,
                data_folder=args.empchat_folder,
                history_len=history_len,
                reactonly=args.reactonly,
                fasttext=args.fasttext,
                fasttext_type=args.fasttext_type,
                fasttext_path=args.fasttext_path,
            )
        elif args.task == "reddit":
            return RedditDataset(
                data_folder=args.reddit_folder,
                chunk_id=999,
                dict_=reddit_dict,
                max_hist_len=history_len,
                rm_blank_sentences=True,
            )
        else:
            raise ValueError("Task unrecognized!")
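
    # Two parallel views of the same split: net_dataset is tokenized with the
    # retrieval model's dictionary (fed to predict()), and bleu_dataset with
    # the BLEU dictionary (used for references and human-readable output)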
    net_dataset = _get_dataset(net_dictionary, net_parlai_dict)
    bleu_dataset = _get_dataset(bleu_dictionary, bleu_parlai_dict)
    sample_index = range(len(bleu_dataset))
    for data_idx in sample_index:
        net_context, _ = net_dataset[data_idx][:2]
        bleu_context, bleu_sentence = bleu_dataset[data_idx][:2]
        target_tokens = bleu_sentence
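        # If fastText classifier labels were prepended to the sentence,
        # args.fasttext is the number of label tokens to strip before scoring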
        if args.fasttext is not None:
            target_tokens = target_tokens[args.fasttext:]
        context = bleu_parlai_dict.vec2txt(bleu_context.numpy().tolist())
        responses, sources = predict(net_context)
        response = responses[0][0]
        source = sources[0]
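        # Count which corpus supplied the top-ranked response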
        if source == "Reddit":
            source_ct[0] += 1
        elif source == "EmpChat":
            source_ct[1] += 1
        else:
            source_ct[2] += 1
        if args.task == "empchat":
            cid, sid = bleu_dataset.getid(data_idx)
        else:
            # Hack: the other datasets have no .getid() method, so use dummy IDs
            cid = sid = -1
        if args.fasttext is not None:
            response = " ".join(response.split()[args.fasttext:])
        outf.write("\t".join([str(cid), str(sid), context, response, source]) + "\n")
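        # Encode the hypothesis with the BLEU dictionary so that it shares a
        # vocabulary with target_tokens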
        hypo_tokens = torch.IntTensor(bleu_parlai_dict.txt2vec(response))
        # Use this tokenization even if a BERT tokenizer exists, to match the
        # BLEU calculation when not using BERT
        scorer.add(target_tokens.type(torch.IntTensor), hypo_tokens)
    outf.close()
    print(scorer.result_string(order=1))
    print(scorer.result_string(order=2))
    print(scorer.result_string(order=3))
    print(scorer.result_string(order=4))
    # actual_ct (defined elsewhere in this script) tallies how many candidates
    # each corpus contributes in total, for comparison with how often each
    # corpus is actually selected
    print(actual_ct)
    print(
        f"EmpatheticDialogues {source_ct[1]:d}: selected "
        f"{source_ct[1] / sum(source_ct):.1%}, but total: "
        f"{actual_ct[1] / sum(actual_ct):.1%}"
    )
    print(
        f"DailyDialog {source_ct[2]:d}: selected "
        f"{source_ct[2] / sum(source_ct):.1%}, but total: "
        f"{actual_ct[2] / sum(actual_ct):.1%}"
    )
    print(
        f"Reddit {source_ct[0]:d}: selected "
        f"{source_ct[0] / sum(source_ct):.1%}, but total: "
        f"{actual_ct[0] / sum(actual_ct):.1%}"
    )
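

# A minimal usage sketch, not from the original file: get_bleu4() depends on
# module-level globals (args, net_dictionary, bleu_dictionary, predict,
# actual_ct) that the surrounding script must set up first. The split names
# below are hypothetical:
#
#     for split in ("valid", "test"):
#         get_bleu4(split, history_len=2)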