in src/weakly_sup.py [0:0]
def extract_probs(batch, criss, lexicon_inducer, info, configs):
    """Score (source word, target word) pairs with ``lexicon_inducer``.

    The pairs are processed in chunks of ``configs.batch_size``. For every
    pair a 7-column feature row is built, in this order:

        0. ``cos(src_encoding, trg_encoding)``
        1. dot product of the two encodings (sum over the last dim)
        2. ``matched_coocc[src][trg]``
        3. ``semi_matched_coocc[src][trg]``
        4. ``coocc[src][trg]``
        5. ``freq_src[src]``
        6. ``freq_trg[trg]``

    Args:
        batch: Sliceable sequence of ``(src_word, trg_word)`` pairs.
        criss: Object exposing ``word_embed(words, lang)``; its output is
            detached before use.
        lexicon_inducer: Callable applied to the feature matrix; its output
            has its last dimension squeezed.
        info: 5-tuple ``(matched_coocc, semi_matched_coocc, coocc, freq_src,
            freq_trg)``. The first three are indexed ``[src][trg]``, the
            last two by a single word.
        configs: Must provide ``batch_size``, ``src_lang``, ``trg_lang`` and
            ``device``.

    Returns:
        The per-chunk outputs of ``lexicon_inducer`` concatenated along
        dim 0. For an empty ``batch`` an empty 1-D float tensor on
        ``configs.device`` is returned.
    """
    matched_coocc, semi_matched_coocc, coocc, freq_src, freq_trg = info

    # torch.cat() raises on an empty list, so an empty batch has to be
    # answered explicitly. len() is used (rather than truthiness) because
    # `batch` is only known to support len() and slicing.
    # NOTE(review): assumes lexicon_inducer emits float tensors on
    # configs.device, like the features it is fed -- confirm.
    if len(batch) == 0:
        return torch.empty(0, dtype=torch.float, device=configs.device)

    all_probs = []
    for start in range(0, len(batch), configs.batch_size):
        subbatch = batch[start:start + configs.batch_size]
        src_words, trg_words = zip(*subbatch)

        # Embeddings are detached: no gradient flows back into `criss`.
        src_encodings = criss.word_embed(src_words, configs.src_lang).detach()
        trg_encodings = criss.word_embed(trg_words, configs.trg_lang).detach()

        # `cos` is a module-level similarity callable defined outside this
        # function.
        cos_sim = cos(src_encodings, trg_encodings).reshape(-1, 1)
        dot_prod = (src_encodings * trg_encodings).sum(-1).reshape(-1, 1)

        # Count-based statistics looked up per word pair.
        stat_rows = [
            [
                matched_coocc[src][trg],
                semi_matched_coocc[src][trg],
                coocc[src][trg],
                freq_src[src],
                freq_trg[trg],
            ]
            for src, trg in zip(src_words, trg_words)
        ]
        stats = torch.tensor(stat_rows).float().to(configs.device)
        stats = stats.reshape(-1, 5)

        features = torch.cat([cos_sim, dot_prod, stats], dim=-1)
        all_probs.append(lexicon_inducer(features).squeeze(-1))

    return torch.cat(all_probs, dim=0)