in src/utils.py [0:0]
def _write_txt_embeddings(path, dico, emb):
    """
    Write `dico` words and their `emb` rows to `path` as text.

    Format: a header line "<n_words> <dim>", then one line per word:
    the word followed by its vector components formatted with '%.5f'.
    Row `i` of `emb` is written for word `dico[i]`, for i in
    range(len(dico)); `emb` must therefore have at least len(dico) rows
    (an IndexError is raised otherwise).
    """
    # Move to CPU once: formatting element-by-element from a GPU tensor
    # would force one device->host transfer per scalar. `.cpu()` is a
    # no-op for tensors already on CPU.
    emb = emb.cpu()
    n_words = len(dico)
    with io.open(path, 'w', encoding='utf-8') as f:
        # The header must describe the rows actually written below
        # (len(dico)), not the raw matrix height, or readers relying on
        # the header count will misparse the file.
        f.write(u"%i %i\n" % (n_words, emb.size(1)))
        # Index by position: the dictionary is accessed via dico[i]
        # (presumably id -> word), not iterated directly.
        for i in range(n_words):
            # .tolist() yields Python floats in one call per row instead
            # of one scalar conversion per element; '%.5f' output is the same.
            vector = " ".join('%.5f' % x for x in emb[i].tolist())
            f.write(u"%s %s\n" % (dico[i], vector))


def export_embeddings(src_emb, tgt_emb, params):
    """
    Export embeddings to a text or a PyTorch file.

    One file per language is written to `params.exp_path`, named
    "vectors-<lang>.txt" or "vectors-<lang>.pth" depending on
    `params.export`. The source file is written before the target file.

    Args:
        src_emb: source embedding matrix; row i corresponds to word
            `params.src_dico[i]`. Must support `.size()` (2-D, presumably a
            torch tensor -- TODO confirm at call sites).
        tgt_emb: target embedding matrix, same layout as `src_emb`.
        params: object providing `export` ("txt" or "pth"), `exp_path`,
            `src_lang`, `tgt_lang`, `src_dico` and `tgt_dico`.

    Raises:
        AssertionError: if `params.export` is not "txt" or "pth".
    """
    assert params.export in ["txt", "pth"]
    exports = [
        ('source', params.src_lang, params.src_dico, src_emb),
        ('target', params.tgt_lang, params.tgt_dico, tgt_emb),
    ]
    for side, lang, dico, emb in exports:
        path = os.path.join(params.exp_path, 'vectors-%s.%s' % (lang, params.export))
        logger.info('Writing %s embeddings to %s ...', side, path)
        # Explicit if/elif (rather than a catch-all else) so that an invalid
        # `export` value writes nothing even when the assert is stripped by -O.
        if params.export == "txt":
            _write_txt_embeddings(path, dico, emb)
        elif params.export == "pth":
            torch.save({'dico': dico, 'vectors': emb}, path)