in xlm/evaluation/evaluator.py [0:0]
def evaluate_mlm(self, scores, data_set, lang1, lang2):
    """
    Evaluate the masked language model on `data_set`.

    Computes masked-token perplexity and prediction accuracy (and, when
    enabled, memory-usage statistics) and writes them into `scores` under
    keys of the form ``{data_set}_{lang_pair}_mlm_ppl`` / ``..._mlm_acc``.

    Args:
        scores: dict accumulating evaluation metrics (mutated in place).
        data_set: which split to evaluate, 'valid' or 'test'.
        lang1: source language code (must be in params.langs).
        lang2: second language code for bilingual (TLM-style) batches,
            or None for monolingual stream evaluation.
    """
    params = self.params
    assert data_set in ['valid', 'test']
    assert lang1 in params.langs
    assert lang2 in params.langs or lang2 is None

    # Select the encoder and strip the DataParallel wrapper if present.
    model = self.model if params.encoder_only else self.encoder
    model.eval()
    if params.multi_gpu:
        model = model.module

    # Fixed seed so the masked positions are identical across evaluations.
    rng = np.random.RandomState(0)
    lang1_id = params.lang2id[lang1]
    lang2_id = None if lang2 is None else params.lang2id[lang2]
    l1l2 = lang1 if lang2 is None else f"{lang1}_{lang2}"

    # Running totals: predicted tokens, cross-entropy mass, correct predictions.
    n_words, xe_loss, n_valid = 0, 0, 0

    # only save states / evaluate usage on the validation set
    eval_memory = params.use_memory and data_set == 'valid' and self.params.is_master
    HashingMemory.EVAL_MEMORY = eval_memory
    if eval_memory:
        all_mem_att = {name: [] for name, _ in self.memory_list}

    monolingual = lang2 is None
    for batch in self.get_iterator(data_set, lang1, lang2, stream=monolingual):

        # build the input batch
        if monolingual:
            x, lengths = batch
            positions = None
            langs = x.clone().fill_(lang1_id) if params.n_langs > 1 else None
        else:
            (sent1, len1), (sent2, len2) = batch
            x, lengths, positions, langs = concat_batches(
                sent1, len1, lang1_id, sent2, len2, lang2_id,
                params.pad_index, params.eos_index, reset_positions=True)

        # choose the tokens to mask / predict
        x, y, pred_mask = self.mask_out(x, lengths, rng)

        # move tensors to GPU
        x, y, pred_mask, lengths, positions, langs = to_cuda(x, y, pred_mask, lengths, positions, langs)

        # forward pass, then score the masked positions
        tensor = model('fwd', x=x, lengths=lengths, positions=positions, langs=langs, causal=False)
        word_scores, loss = model('predict', tensor=tensor, pred_mask=pred_mask, y=y, get_scores=True)

        # accumulate statistics
        n_predicted = len(y)
        n_words += n_predicted
        xe_loss += loss.item() * n_predicted
        n_valid += (word_scores.max(1)[1] == y).sum().item()
        if eval_memory:
            for name, layer in self.memory_list:
                all_mem_att[name].append((layer.last_indices, layer.last_scores))

    # compute perplexity and prediction accuracy
    ppl_name = '%s_%s_mlm_ppl' % (data_set, l1l2)
    acc_name = '%s_%s_mlm_acc' % (data_set, l1l2)
    scores[ppl_name] = np.exp(xe_loss / n_words) if n_words > 0 else 1e9
    scores[acc_name] = 100. * n_valid / n_words if n_words > 0 else 0.

    # compute memory usage
    if eval_memory:
        for mem_name, mem_att in all_mem_att.items():
            eval_memory_usage(scores, '%s_%s_%s' % (data_set, l1l2, mem_name), mem_att, params.mem_size)