def extract_probs()

in src/weakly_sup.py [0:0]


def extract_probs(batch, criss, lexicon_inducer, info, configs):
    """Score (source word, target word) pairs with ``lexicon_inducer``.

    The pairs are processed in chunks of ``configs.batch_size``. For each
    pair a 7-column feature row is built, in this order:

        0. ``cos(src_emb, trg_emb)``
        1. dot product of the two embeddings (sum over the last dim)
        2. ``matched_coocc[src][trg]``
        3. ``semi_matched_coocc[src][trg]``
        4. ``coocc[src][trg]``
        5. ``freq_src[src]``
        6. ``freq_trg[trg]``

    Args:
        batch: Sliceable sequence of ``(src_word, trg_word)`` pairs.
        criss: Object exposing ``word_embed(words, lang)``; its outputs are
            detached, so no gradient flows back into the embedder.
        lexicon_inducer: Callable applied to the ``(n, 7)`` feature tensor.
        info: 5-tuple ``(matched_coocc, semi_matched_coocc, coocc,
            freq_src, freq_trg)`` of lookup tables indexed by word.
        configs: Provides ``batch_size``, ``src_lang``, ``trg_lang`` and
            ``device``.

    Returns:
        The per-chunk outputs of ``lexicon_inducer`` (trailing dim squeezed)
        concatenated along dim 0. An empty ``batch`` yields an empty 1-D
        tensor on ``configs.device`` instead of raising inside ``torch.cat``.
    """
    matched_coocc, semi_matched_coocc, coocc, freq_src, freq_trg = info
    all_probs = []
    for start in range(0, len(batch), configs.batch_size):
        subbatch = batch[start:start + configs.batch_size]
        src_words, trg_words = zip(*subbatch)
        src_encodings = criss.word_embed(src_words, configs.src_lang).detach()
        trg_encodings = criss.word_embed(trg_words, configs.trg_lang).detach()
        # `cos` is defined at module level outside this function; presumably
        # a cosine similarity over the last dimension -- TODO confirm.
        cos_sim = cos(src_encodings, trg_encodings).reshape(-1, 1)
        dot_prod = (src_encodings * trg_encodings).sum(-1).reshape(-1, 1)
        # Count-based statistics, one row of five values per word pair.
        count_features = torch.tensor(
            [
                [
                    matched_coocc[src][trg],
                    semi_matched_coocc[src][trg],
                    coocc[src][trg],
                    freq_src[src],
                    freq_trg[trg],
                ]
                for src, trg in zip(src_words, trg_words)
            ]
        ).float().to(configs.device).reshape(-1, 5)
        features = torch.cat([cos_sim, dot_prod, count_features], dim=-1)
        all_probs.append(lexicon_inducer(features).squeeze(-1))
    if not all_probs:
        # torch.cat rejects an empty list; give callers a well-formed empty
        # result (default float dtype, matching the float features above).
        return torch.empty(0, device=configs.device)
    return torch.cat(all_probs, dim=0)