# in src/dico_builder.py [0:0]
def build_dictionary(src_emb, tgt_emb, params, s2t_candidates=None, t2s_candidates=None):
    """
    Build a training dictionary given current embeddings / mapping.

    The candidate sets that are used, and how they are combined, is selected
    by `params.dico_build`:
        'S2T'      use the source-to-target candidates as they are
        'T2S'      use the (column-rotated) target-to-source candidates
        'S2T|T2S'  union of the two candidate pair sets
        'S2T&T2S'  intersection of the two candidate pair sets

    Args:
        src_emb: source embeddings, only forwarded to `get_candidates`.
        tgt_emb: target embeddings, only forwarded to `get_candidates`.
        params: object exposing `dico_build` (one of the modes above) and
            `cuda` (when truthy the result is moved to the GPU).
        s2t_candidates: optional precomputed source-to-target candidates;
            computed with `get_candidates` when None and the mode needs them.
        t2s_candidates: optional precomputed target-to-source candidates in
            the layout `get_candidates(tgt_emb, src_emb, params)` returns.
            Their first column is moved to the end here, whether they were
            supplied or computed.
            NOTE(review): candidates are assumed to be 2-D integer tensors
            with one (index, index) pair per row -- confirm against
            `get_candidates`.

    Returns:
        A tensor with one pair per row (moved to the GPU when `params.cuda`),
        or None when the 'S2T&T2S' intersection is empty.

    Raises:
        AssertionError: if `params.dico_build` is not a supported mode.
    """
    logger.info("Building the train dictionary ...")

    mode = params.dico_build
    valid_modes = ('S2T', 'T2S', 'S2T|T2S', 'S2T&T2S')
    if mode not in valid_modes:
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`; AssertionError is kept so callers see the same type.
        raise AssertionError(
            "Unknown dico_build mode %r, expected one of: %s"
            % (mode, ', '.join(valid_modes))
        )
    use_s2t = 'S2T' in mode
    use_t2s = 'T2S' in mode

    if use_s2t and s2t_candidates is None:
        s2t_candidates = get_candidates(src_emb, tgt_emb, params)
    if use_t2s:
        if t2s_candidates is None:
            t2s_candidates = get_candidates(tgt_emb, src_emb, params)
        # Move the first column to the end so the rows use the same column
        # order as the S2T candidates (a plain swap for two-column input).
        t2s_candidates = torch.cat([t2s_candidates[:, 1:], t2s_candidates[:, :1]], 1)

    if mode == 'S2T':
        dico = s2t_candidates
    elif mode == 'T2S':
        dico = t2s_candidates
    else:
        # `.tolist()` (rather than `.numpy()`) also works for tensors that do
        # not live on the CPU and yields plain Python ints, which hash like
        # the NumPy integers used before.
        s2t_pairs = {(a, b) for a, b in s2t_candidates.tolist()}
        t2s_pairs = {(a, b) for a, b in t2s_candidates.tolist()}
        if mode == 'S2T|T2S':
            final_pairs = s2t_pairs | t2s_pairs
        else:
            final_pairs = s2t_pairs & t2s_pairs
            if not final_pairs:
                logger.warning("Empty intersection ...")
                return None
        dico = torch.LongTensor([[int(a), int(b)] for a, b in final_pairs])

    logger.info('New train dictionary of %i pairs.', dico.size(0))
    return dico.cuda() if params.cuda else dico