in src/models.py [0:0]
def build_model(params, with_dis):
    """
    Assemble the embedding layers, the linear mapping and, optionally,
    the discriminator.

    Side effects on `params`:
      - `params.src_dico` is set to the source dictionary returned by
        `load_embeddings`.
      - `params.tgt_dico` is set likewise, only when `params.tgt_lang`
        is truthy.
      - `params.src_mean` (and `params.tgt_mean` when there is a target
        language) receive the value returned by `normalize_embeddings`.

    Returns the tuple `(src_emb, tgt_emb, mapping, discriminator)`, where
    `tgt_emb` is None if `params.tgt_lang` is falsy and `discriminator`
    is None if `with_dis` is falsy.
    """

    def _embedding_layer(dico, pretrained):
        # One row per dictionary entry, filled with the loaded vectors.
        layer = nn.Embedding(len(dico), params.emb_dim, sparse=True)
        layer.weight.data.copy_(pretrained)
        return layer

    # source side
    src_dico, src_vectors = load_embeddings(params, source=True)
    params.src_dico = src_dico
    src_emb = _embedding_layer(src_dico, src_vectors)

    # target side (optional)
    tgt_emb = None
    if params.tgt_lang:
        tgt_dico, tgt_vectors = load_embeddings(params, source=False)
        params.tgt_dico = tgt_dico
        tgt_emb = _embedding_layer(tgt_dico, tgt_vectors)

    # square, bias-free mapping; starts as the identity unless
    # `params.map_id_init` is present and falsy
    mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
    if getattr(params, 'map_id_init', True):
        mapping.weight.data.copy_(torch.eye(params.emb_dim))

    # discriminator (optional)
    discriminator = None
    if with_dis:
        discriminator = Discriminator(params)

    # move every module that was actually built to the GPU
    if params.cuda:
        for module in (src_emb, tgt_emb, mapping, discriminator):
            if module is not None:
                module.cuda()

    # normalize the embedding tables, after the optional device move
    params.src_mean = normalize_embeddings(
        src_emb.weight.data, params.normalize_embeddings
    )
    if params.tgt_lang:
        params.tgt_mean = normalize_embeddings(
            tgt_emb.weight.data, params.normalize_embeddings
        )

    return src_emb, tgt_emb, mapping, discriminator