in NMT/src/model/pretrain_embeddings.py [0:0]
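# NOTE: assumed context, not shown in this excerpt: the module presumably imports
# `os` and `torch`, defines a module-level `logger`, and provides
# `reload_embeddings(path, emb_dim)` (called below), which is assumed to return
# (pretrained vectors as a numpy array, word -> row-index dict).
import os
import torch
from logging import getLogger

logger = getLogger()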
def initialize_embeddings(encoder, decoder, params, data):
    """
    Initialize the model with pretrained embeddings.
    """
    if params.pretrained_emb == '':
        return
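    # `params.pretrained_emb` is either a single path (the same pretrained
    # embeddings are then reused for every language) or a comma-separated
    # list of paths, one per language, in the order of `params.langs`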
    split = params.pretrained_emb.split(',')
    if len(split) == 1:
        assert os.path.isfile(params.pretrained_emb)
        pretrained_0, word2id_0 = reload_embeddings(params.pretrained_emb, params.emb_dim)
        pretrained = [pretrained_0 for _ in range(params.n_langs)]
        word2id = [word2id_0 for _ in range(params.n_langs)]
    else:
        assert len(split) == params.n_langs
        assert not params.share_lang_emb
        assert all(os.path.isfile(x) for x in split)
        pretrained = []
        word2id = []
        for path in split:
            pretrained_i, word2id_i = reload_embeddings(path, params.emb_dim)
            pretrained.append(pretrained_i)
            word2id.append(word2id_i)
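    # sharing language embeddings only makes sense if all languages use the same dictionary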
    assert not params.share_lang_emb or all(data['dico'][params.langs[i]] == data['dico'][params.langs[0]] for i in range(1, params.n_langs))
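    # per-language counters: words initialized, and words matched only after lowercasing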
    found = [0 for _ in range(params.n_langs)]
    lower = [0 for _ in range(params.n_langs)]
    # for every language
    for i, lang in enumerate(params.langs):

        # if everything is shared, we just need to do it for the first language
        if params.share_lang_emb and i > 0:
            break

        # define dictionary / parameters to update
        dico = data['dico'][lang]
        to_update = [encoder.embeddings[i].weight.data]
        if not params.share_encdec_emb:
            to_update.append(decoder.embeddings[i].weight.data)
        if not params.share_decpro_emb and params.pretrained_out:
            to_update.append(decoder.proj[i].weight.data)
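        # copy each pretrained vector into every matrix in `to_update`; fall back
        # to the lowercased form when the exact word is missing from the
        # pretrained vocabulary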
        # for every word in that language
        for word_id in range(params.n_words[i]):
            word = dico[word_id]
            if word in word2id[i]:
                found[i] += 1
                vec = torch.from_numpy(pretrained[i][word2id[i][word]]).cuda()
                for x in to_update:
                    x[word_id] = vec
            elif word.lower() in word2id[i]:
                found[i] += 1
                lower[i] += 1
                vec = torch.from_numpy(pretrained[i][word2id[i][word.lower()]]).cuda()
                for x in to_update:
                    x[word_id] = vec
    # print summary
    for i, lang in enumerate(params.langs):
        _found = found[0 if params.share_lang_emb else i]
        _lower = lower[0 if params.share_lang_emb else i]
        logger.info(
            "Initialized %i / %i word embeddings for \"%s\" (including %i "
            "after lowercasing)." % (_found, params.n_words[i], lang, _lower)
        )
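# Minimal usage sketch (hypothetical values; assumes `encoder`, `decoder`,
# `params` and `data` were built by the surrounding training code):
#
#   params.pretrained_emb = 'en.vec,fr.vec'   # one embedding file per language
#   params.pretrained_out = True              # also initialize decoder.proj rows
#   initialize_embeddings(encoder, decoder, params, data)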