in src/evaluation/word_translation.py [0:0]
def get_word_translation_accuracy(lang1, word2id1, emb1, lang2, word2id2, emb2, method, dico_eval):
    """
    Given source and target word embeddings, and a dictionary,
    evaluate the translation accuracy using the precision@k.

    Args:
        lang1, lang2: source / target language codes; only used to build the
            default evaluation dictionary path.
        word2id1, word2id2: word -> row-index mappings for ``emb1`` / ``emb2``;
            forwarded to ``load_dictionary``.
        emb1, emb2: source / target embedding matrices, indexed by word id
            along their first dimension.
        method: retrieval criterion, one of ``'nn'`` (cosine nearest
            neighbour), ``'invsm_beta_<float>'`` (inverted softmax with
            temperature ``beta``) or ``'csls_knn_<int>'`` (CSLS with ``knn``
            neighbours).
        dico_eval: ``'default'`` to use
            ``DIC_EVAL_PATH/<lang1>-<lang2>.5000-6500.txt``, otherwise the
            path of the evaluation dictionary.

    Returns:
        A list of ``('precision_at_<k>', value)`` tuples for k in 1, 5, 10,
        where ``value`` is a percentage in [0, 100].

    Raises:
        ValueError: if ``method`` is not one of the supported criteria.
    """
    if dico_eval == 'default':
        path = os.path.join(DIC_EVAL_PATH, '%s-%s.5000-6500.txt' % (lang1, lang2))
    else:
        path = dico_eval
    dico = load_dictionary(path, word2id1, word2id2)
    dico = dico.cuda() if emb1.is_cuda else dico
    assert dico[:, 0].max() < emb1.size(0)
    assert dico[:, 1].max() < emb2.size(0)

    # normalize word embeddings; the clamp keeps an all-zero row from becoming
    # NaN (NaN scores would make the top-k ranking below meaningless)
    emb1 = emb1 / emb1.norm(2, 1, keepdim=True).clamp(min=1e-12).expand_as(emb1)
    emb2 = emb2 / emb2.norm(2, 1, keepdim=True).clamp(min=1e-12).expand_as(emb2)

    # nearest neighbors
    if method == 'nn':
        query = emb1[dico[:, 0]]
        scores = query.mm(emb2.transpose(0, 1))

    # inverted softmax
    elif method.startswith('invsm_beta_'):
        beta = float(method[len('invsm_beta_'):])
        bs = 128
        word_scores = []
        for i in range(0, emb2.size(0), bs):
            # similarities of every source word to this batch of target words
            scores = emb1.mm(emb2[i:i + bs].transpose(0, 1))
            # softmax over the source words (dim 0) for each target word;
            # subtracting the column max before exp() prevents float overflow
            # (inf / inf = NaN) for large beta without changing the result
            scores.mul_(beta)
            scores.sub_(scores.max(0, keepdim=True)[0].expand_as(scores))
            scores.exp_()
            scores.div_(scores.sum(0, keepdim=True).expand_as(scores))
            # keep only the rows of the dictionary's source words
            word_scores.append(scores.index_select(0, dico[:, 0]))
        scores = torch.cat(word_scores, 1)

    # contextual dissimilarity measure
    elif method.startswith('csls_knn_'):
        # average distances to k nearest neighbors
        knn = method[len('csls_knn_'):]
        assert knn.isdigit()
        knn = int(knn)
        average_dist1 = get_nn_avg_dist(emb2, emb1, knn)
        average_dist2 = get_nn_avg_dist(emb1, emb2, knn)
        average_dist1 = torch.from_numpy(average_dist1).type_as(emb1)
        average_dist2 = torch.from_numpy(average_dist2).type_as(emb2)
        # queries / scores: 2 * cos(x, y) - r_src(x) - r_tgt(y)
        query = emb1[dico[:, 0]]
        scores = query.mm(emb2.transpose(0, 1))
        scores.mul_(2)
        scores.sub_(average_dist1[dico[:, 0]][:, None])
        scores.sub_(average_dist2[None, :])

    else:
        raise ValueError('Unknown method: "%s"' % method)

    # never request more candidates than there are target words
    # (topk raises if k exceeds the size of the dimension)
    top_matches = scores.topk(min(10, scores.size(1)), 1, True)[1]

    # loop-invariant views of the dictionary, hoisted out of the k loop
    src_ids = dico[:, 0].cpu().numpy()
    tgt_ids = dico[:, 1][:, None]

    results = []
    for k in [1, 5, 10]:
        top_k_matches = top_matches[:, :k]
        _matching = (top_k_matches == tgt_ids.expand_as(top_k_matches)).sum(1).cpu().numpy()
        # allow for multiple possible translations: a source word counts as
        # correct once if any of its reference translations is retrieved
        matching = {}
        for i, src_id in enumerate(src_ids):
            matching[src_id] = min(matching.get(src_id, 0) + _matching[i], 1)
        # evaluate precision@k
        precision_at_k = 100 * np.mean(list(matching.values()))
        logger.info("%i source words - %s - Precision at k = %i: %f" %
                    (len(matching), method, k, precision_at_k))
        results.append(('precision_at_%i' % k, precision_at_k))
    return results