in XLM/src/evaluation/evaluator.py
def evaluate_mt(self, scores, data_set, lang1, lang2, eval_bleu, eval_computation):
"""
Evaluate perplexity and next word prediction accuracy.
"""
params = self.params
assert data_set in ['valid', 'test']
assert lang1 in params.langs
assert lang2 in params.langs
lang1_id = params.lang2id[lang1]
lang2_id = params.lang2id[lang2]
self.eval_mode()
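    # unwrap the DataParallel wrapper in multi-GPU mode; with separate
    # decoders, each target language owns a dedicated decoder module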
encoder = self.encoder[0].module if params.multi_gpu else self.encoder[0]
decoder = self.decoder[lang2_id] if params.separate_decoders else self.decoder[0]
decoder = decoder.module if params.multi_gpu else decoder
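    # running totals used for perplexity and next-word prediction accuracy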
n_words = 0
xe_loss = 0
n_valid = 0
    # store hypotheses to compute BLEU score
    datasets_for_bleu = ['test'] if params.eval_bleu_test_only else ['test', 'valid']
    if (eval_bleu or eval_computation) and data_set in datasets_for_bleu:
        hypothesis = []
for i, batch in enumerate(self.get_iterator(data_set, lang1, lang2)):
(x1, len1, ids1, len_ids1), (x2, len2, ids2, len_ids2) = batch
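        # language-ID tensors, one ID per token position (same shape as x1/x2)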
langs1 = x1.clone().fill_(lang1_id)
langs2 = x2.clone().fill_(lang2_id)
# target words to predict
alen = torch.arange(
len2.max(), dtype=torch.long, device=len2.device)
# do not predict anything given the last target word
pred_mask = alen[:, None] < len2[None] - 1
y = x2[1:].masked_select(pred_mask[:-1])
assert len(y) == (len2 - 1).sum().item()
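        # e.g. for a target of length 4, y holds the tokens at positions
        # 1..3: every token is predicted except the start-of-sentence symbol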
        # skip batches whose target exceeds 1024 tokens (presumably the
        # model's maximum sequence length) before any computation is done
        if len2.max().item() > 1024:
            print('skipping a batch with a target sentence longer than 1024 tokens')
            continue
        # cuda
        x1, len1, langs1, x2, len2, langs2, y = to_cuda(
            x1, len1, langs1, x2, len2, langs2, y)
        # encode source sentence
        enc1 = encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False)
        enc1 = enc1.transpose(0, 1)
        enc1 = enc1.half() if params.fp16 else enc1
        # decode target sentence
        dec2 = decoder('fwd', x=x2, lengths=len2, langs=langs2,
                       causal=True, src_enc=enc1, src_len=len1)
# loss
word_scores, loss = decoder(
'predict', tensor=dec2, pred_mask=pred_mask, y=y, get_scores=True)
# update stats
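        # 'loss' is the mean cross-entropy per predicted token, so scale
        # by len(y) to accumulate a corpus-level total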
n_words += y.size(0)
xe_loss += loss.item() * len(y)
n_valid += (word_scores.max(1)[1] == y).sum().item()
# generate translation - translate / convert to text
if (eval_bleu or eval_computation) and data_set in datasets_for_bleu:
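            # generation budget: a heuristic of 3x the source length plus a
            # constant, capped at the model's maximum output length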
len_v = (3 * len1 + 10).clamp(max=params.max_len)
if params.beam_size == 1:
if params.number_samples > 1:
assert params.eval_temperature is not None
                    generated, lengths = decoder.generate(
                        enc1.repeat_interleave(params.number_samples, dim=0),
                        len1.repeat_interleave(params.number_samples, dim=0),
                        lang2_id,
                        max_len=len_v.repeat_interleave(
                            params.number_samples, dim=0),
                        sample_temperature=params.eval_temperature)
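                    # regroup the flat (bs * n_samples) generations so the
                    # samples of each source sentence are contiguous, and
                    # keep the longest sample length per source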
                    generated = generated.T.reshape(
                        -1, params.number_samples, generated.shape[0]).T
                    lengths, _ = lengths.reshape(
                        -1, params.number_samples).max(dim=1)
else:
generated, lengths = decoder.generate(
enc1, len1, lang2_id, max_len=len_v)
else:
assert params.number_samples == 1
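                # multi-sampling is only supported with beam_size == 1; beam
                # search already yields several hypotheses per source sentence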
generated, lengths = decoder.generate_beam(
enc1, len1, lang2_id, beam_size=params.beam_size,
length_penalty=params.length_penalty,
early_stopping=params.early_stopping,
max_len=len_v
)
hypothesis.extend(convert_to_text(
generated, lengths, self.dico, params, generate_several_reps=True))
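            # each entry of 'hypothesis' is a list with one string per
            # beam / sample for the corresponding source sentence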
# compute perplexity and prediction accuracy
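    # perplexity = exp(total cross-entropy / number of predicted tokens)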
scores['%s_%s-%s_mt_ppl' %
(data_set, lang1, lang2)] = np.exp(xe_loss / n_words)
scores['%s_%s-%s_mt_acc' %
(data_set, lang1, lang2)] = 100. * n_valid / n_words
# write hypotheses
if (eval_bleu or eval_computation) and data_set in datasets_for_bleu:
# hypothesis / reference paths
hyp_paths = []
ref_path = params.ref_paths[(lang1, lang2, data_set)]
# export sentences to hypothesis file / restore BPE segmentation
for beam_number in range(len(hypothesis[0])):
hyp_name = 'hyp{0}.{1}-{2}.{3}_beam{4}.txt'.format(
scores['epoch'], lang1, lang2, data_set, beam_number)
hyp_path = os.path.join(params.hyp_path, hyp_name)
hyp_paths.append(hyp_path)
            print(f'outputting hypotheses in {hyp_path}')
with open(hyp_path, 'w', encoding='utf-8') as f:
f.write('\n'.join([hyp[beam_number]
for hyp in hypothesis]) + '\n')
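            # merge BPE subwords back into full tokens (in XLM this strips
            # the '@@' continuation markers from the written file)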
restore_segmentation(hyp_path)
    # check how many generated functions compile and return the same output
    # as the ground truth
    if eval_computation and data_set in datasets_for_bleu:
        func_run_stats, func_run_out = eval_function_output(
            ref_path,
            hyp_paths,
            params.id_paths[(lang1, lang2, data_set)],
            lang2,
            params.eval_scripts_folders[(lang1, lang2, data_set)],
            EVAL_SCRIPT_FOLDER[data_set],
            params.retry_mistmatching_types,
        )
        out_paths = []
        success_for_beam_number = [0 for _ in range(len(hypothesis[0]))]
        for beam_number in range(len(hypothesis[0])):
            out_name = 'hyp{0}.{1}-{2}.{3}_beam{4}.out.txt'.format(
                scores['epoch'], lang1, lang2, data_set, beam_number)
            out_path = os.path.join(params.hyp_path, out_name)
            out_paths.append(out_path)
            with open(out_path, 'w', encoding='utf-8') as f:
                for results_list in func_run_out:
                    result_for_beam = (
                        results_list[beam_number]
                        if beam_number < len(results_list) else ''
                    )
                    if result_for_beam.startswith('success'):
                        success_for_beam_number[beam_number] += 1
                    f.write(result_for_beam + '\n')
                f.write('\n')
        vizualize_translated_files(
            lang1, lang2,
            params.ref_paths[(lang2, lang1, data_set)],
            hyp_paths,
            params.id_paths[(lang1, lang2, data_set)],
            ref_path,
            out_paths,
        )
logger.info("Computation res %s %s %s : %s" %
(data_set, lang1, lang2, json.dumps(func_run_stats)))
        total_evaluated = func_run_stats['total_evaluated'] or 1
        scores['%s_%s-%s_mt_comp_acc' % (data_set, lang1, lang2)] = (
            func_run_stats['success'] / total_evaluated)
        for beam_number, success_for_beam in enumerate(success_for_beam_number):
            scores['%s_%s-%s_mt_comp_acc_contrib_beam_%i' %
                   (data_set, lang1, lang2, beam_number)] = (
                success_for_beam / total_evaluated)
for out_path in out_paths:
Path(out_path).unlink()
    # compute BLEU score (against the first beam / sample only)
    if eval_bleu and data_set in datasets_for_bleu:
        bleu = eval_moses_bleu(ref_path, hyp_paths[0])
logger.info("BLEU %s %s : %f" % (hyp_paths[0], ref_path, bleu))
scores['%s_%s-%s_mt_bleu' % (data_set, lang1, lang2)] = bleu
if eval_computation:
for hyp_path in hyp_paths:
Path(hyp_path).unlink()