# src/evaluation/wordsim.py
def get_wordanalogy_scores(language, word2id, embeddings, lower=True):
    """
    Compute word analogy accuracies (English only) on questions-words.txt.

    For every analogy line ``a b c d`` a 3CosAdd query ``a - b + d`` is
    built, and the nearest embedding (excluding a, b and d themselves)
    must be ``c`` for the analogy to count as correct.

    Parameters:
        language (str): language code; only "en" has analogy data.
        word2id (dict): mapping from word to row index in `embeddings`.
        embeddings (np.ndarray): (n_words, dim) word embedding matrix.
        lower (bool): lowercase the evaluation file before lookups.

    Returns:
        dict mapping category name to accuracy, or None when the language
        is unsupported or the evaluation data directory is missing.
    """
    # Only English analogy data exists; bail out before touching the
    # filesystem so callers evaluating other languages get None cheaply.
    if language not in ["en"]:
        return None
    dirpath = os.path.join(MONOLINGUAL_EVAL_PATH, language)
    if not os.path.isdir(dirpath):
        return None
    # L2-normalize the rows so dot products below are cosine similarities.
    embeddings = embeddings / np.sqrt((embeddings ** 2).sum(1))[:, None]
    # Per-category bookkeeping.
    scores = {}    # category -> {'n_found', 'n_not_found', 'n_correct'}
    word_ids = {}  # category -> list of [id1, id2, id3, id4]
    queries = {}   # category -> list of unit-norm query vectors
    with io.open(os.path.join(dirpath, 'questions-words.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip()
            if lower:
                line = line.lower()
            # Category headers look like ": capital-common-countries".
            if ":" in line:
                assert line[1] == ' '
                category = line[2:]
                assert category not in scores
                scores[category] = {'n_found': 0, 'n_not_found': 0, 'n_correct': 0}
                word_ids[category] = []
                queries[category] = []
                continue
            # Analogy line "a b c d", meaning a:b :: c:d.
            assert len(line.split()) == 4, line
            word1, word2, word3, word4 = line.split()
            word_id1 = get_word_id(word1, word2id, lower)
            word_id2 = get_word_id(word2, word2id, lower)
            word_id3 = get_word_id(word3, word2id, lower)
            word_id4 = get_word_id(word4, word2id, lower)
            # Skip analogies containing an out-of-vocabulary word.
            if any(x is None for x in [word_id1, word_id2, word_id3, word_id4]):
                scores[category]['n_not_found'] += 1
                continue
            scores[category]['n_found'] += 1
            word_ids[category].append([word_id1, word_id2, word_id3, word_id4])
            # 3CosAdd query (a - b + d), normalized; its nearest neighbor
            # should be c given the file's column order.
            query = embeddings[word_id1] - embeddings[word_id2] + embeddings[word_id4]
            query = query / np.linalg.norm(query)
            queries[category].append(query)
    # Score each category with a single matrix multiply.
    # `keys` is loop-invariant, so build it once outside the loop.
    keys = torch.from_numpy(embeddings.T)
    for cat in queries:
        if len(queries[cat]) == 0:
            # No in-vocabulary analogy in this category: np.vstack would
            # raise ValueError on an empty list, and accuracy is 0 anyway.
            continue
        qs = torch.from_numpy(np.vstack(queries[cat]))
        values = qs.mm(keys).cpu().numpy()
        # Mask the three input words so they cannot be predicted.
        for i, ws in enumerate(word_ids[cat]):
            for wid in [ws[0], ws[1], ws[3]]:
                values[i, wid] = -1e9
        scores[cat]['n_correct'] = np.sum(values.argmax(axis=1) == [ws[2] for ws in word_ids[cat]])
    # Pretty-print a fixed-width results table.
    separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12)
    pattern = "%30s %10s %13s %12s"
    logger.info(separator)
    logger.info(pattern % ("Category", "Found", "Not found", "Accuracy"))
    logger.info(separator)
    # Compute and log per-category accuracies (0 when nothing was found).
    accuracies = {}
    for k in sorted(scores.keys()):
        v = scores[k]
        accuracies[k] = float(v['n_correct']) / max(v['n_found'], 1)
        logger.info(pattern % (k, str(v['n_found']), str(v['n_not_found']), "%.4f" % accuracies[k]))
    logger.info(separator)
    return accuracies