def get_word_aligns()

in align/models.py [0:0]


    def get_word_aligns(self, src_sent, trg_sent, langcodes=None, fwd_dict=None, bwd_dict=None, debug=False):
        """Compute word-level alignments between two pre-split sentences.

        Every word is split into subword tokens with ``self.tokenizer``, the
        subwords are embedded with ``self.get_embed``, and a cosine-similarity
        matrix rescaled as ``(cos + 1) / 2`` is passed through
        ``self.apply_distortion``.  One alignment matrix is then built per
        method selected by the characters of ``self.matching_method``
        (``'a'`` -> ``'inter'``, ``'i'`` -> ``'itermax'``, ``'m'`` -> ``'mwmf'``,
        ``'f'`` -> ``'fixed'``), and every subword link whose matrix value
        exceeds ``1e-10`` is projected onto the pair of word indices it
        belongs to.

        For ``'fixed'`` the intersection alignment is additionally extended
        with links suggested by the translation dictionaries for words that
        are still unaligned.

        Args:
            src_sent: Sequence of source words (iterated several times and
                measured with ``len``).
            trg_sent: Sequence of target words.
            langcodes: Forwarded unchanged to ``self.get_embed``.
            fwd_dict: Nested mapping ``lowercased source word ->
                {lowercased target word -> score}``; only consulted for the
                ``'fixed'`` method.  ``None`` is treated as an empty
                dictionary, and words without an entry contribute no
                candidates.
            bwd_dict: Same as ``fwd_dict`` in the opposite direction
                (``target word -> {source word -> score}``).
            debug: When true, print every dictionary link that gets added.

        Returns:
            dict mapping each computed method name to a sorted list of
            ``(src_word_index, trg_word_index)`` tuples.
        """
        l1_tokens = [self.tokenizer.tokenize(word) for word in src_sent]
        l2_tokens = [self.tokenizer.tokenize(word) for word in trg_sent]
        bpe_lists = [[bpe for w in sent for bpe in w] for sent in (l1_tokens, l2_tokens)]

        # Subword position -> index of the word that subword was split from.
        l1_b2w_map = [i for i, wlist in enumerate(l1_tokens) for _ in wlist]
        l2_b2w_map = [j for j, wlist in enumerate(l2_tokens) for _ in wlist]

        vectors = self.get_embed(list(bpe_lists), langcodes)
        # Shift cosine similarity from [-1, 1] into [0, 1].
        sim = (cosine_similarity(vectors[0], vectors[1]) + 1.0) / 2.0
        sim = self.apply_distortion(sim, self.distortion)

        all_mats = dict()
        fwd, bwd = self.get_alignment_matrix(sim)
        if 'a' in self.matching_method:
            all_mats['inter'] = fwd * bwd
        if 'i' in self.matching_method:
            all_mats['itermax'] = self.iter_max(sim)
        if 'm' in self.matching_method:
            all_mats['mwmf'] = self.get_max_weight_match(sim)
        if 'f' in self.matching_method:
            # Starts as the plain intersection; extended with dictionary
            # links further below.
            all_mats['fixed'] = fwd * bwd

        # Project subword-level links onto word-level index pairs.
        n_src_bpe = vectors[0].shape[0]
        n_trg_bpe = vectors[1].shape[0]
        aligns = {}
        for key, mat in all_mats.items():
            links = set()
            for i in range(n_src_bpe):
                for j in range(n_trg_bpe):
                    if mat[i, j] > 1e-10:
                        links.add((l1_b2w_map[i], l2_b2w_map[j]))
            aligns[key] = links

        if 'fixed' in aligns:
            fixed = aligns['fixed']
            # The dictionaries are optional: without them 'fixed' simply
            # stays equal to the intersection instead of raising TypeError.
            fwd_dict = {} if fwd_dict is None else fwd_dict
            bwd_dict = {} if bwd_dict is None else bwd_dict

            src_aligned = {i for i, _ in fixed}
            trg_aligned = {j for _, j in fixed}
            src_words = [w.lower() for w in src_sent]
            trg_words = [w.lower() for w in trg_sent]
            n_src = len(src_words)
            n_trg = len(trg_words)

            # Candidate tuples: (src word, trg word, i, j, score, direction),
            # direction 0 = found via fwd_dict, 1 = found via bwd_dict.
            candidate_alignment = list()
            for i, sw in enumerate(src_words):
                if i in src_aligned:
                    continue
                # .get() keeps out-of-vocabulary words from raising KeyError.
                entries = fwd_dict.get(sw, {})
                for j, tw in enumerate(trg_words):
                    if tw in entries:
                        # Only pair words at similar relative positions.
                        if -0.2 < i / n_src - j / n_trg < 0.2:
                            candidate_alignment.append((sw, tw, i, j, entries[tw], 0))
            for j, tw in enumerate(trg_words):
                if j in trg_aligned:
                    continue
                entries = bwd_dict.get(tw, {})
                for i, sw in enumerate(src_words):
                    if sw in entries:
                        if -0.2 < i / n_src - j / n_trg < 0.2:
                            candidate_alignment.append((sw, tw, i, j, entries[sw], 1))

            # Best score first; the sort is stable, so ties keep insertion order.
            candidate_alignment = sorted(candidate_alignment, key=lambda cand: -cand[4])
            for sw, tw, i, j, val, d in candidate_alignment:
                # Skip candidates where either word starts with punctuation.
                if regex.match(r'\p{P}', sw) or regex.match(r'\p{P}', tw):
                    continue
                # Candidates are sorted, so everything after this is weaker too.
                if val < 0.05:
                    break
                if d == 0:
                    if i in src_aligned:
                        continue
                    # Accept if the target word is free, or if it is already
                    # linked to a source word adjacent to i.
                    if (j not in trg_aligned) or ((i - 1, j) in fixed) or ((i + 1, j) in fixed):
                        fixed.add((i, j))
                        src_aligned.add(i)
                        trg_aligned.add(j)
                        if debug:
                            print(sw, tw, i, j, val, d)
                else:
                    if j in trg_aligned:
                        continue
                    # Mirror case: the source word is free, or it is already
                    # linked to a target word adjacent to j.
                    if (i not in src_aligned) or ((i, j + 1) in fixed) or ((i, j - 1) in fixed):
                        fixed.add((i, j))
                        src_aligned.add(i)
                        trg_aligned.add(j)
                        if debug:
                            print(sw, tw, i, j, val, d)

        return {name: sorted(links) for name, links in aligns.items()}