in align/models.py [0:0]
def get_word_aligns(self, src_sent, trg_sent, langcodes=None, fwd_dict=None,
                    bwd_dict=None, debug=False, max_rel_dist=0.2, min_prob=0.05):
    """Compute word alignments between ``src_sent`` and ``trg_sent``.

    Each word is split into subword tokens with ``self.tokenizer.tokenize``.
    The subword sequences are embedded with ``self.get_embed``. Cosine
    similarity is rescaled from [-1, 1] to [0, 1] and passed through
    ``self.apply_distortion``.

    One alignment matrix is built per flag character present in
    ``self.matching_method``:

    * ``'a'`` -> ``'inter'``: ``fwd * bwd`` from ``self.get_alignment_matrix``.
    * ``'i'`` -> ``'itermax'``: ``self.iter_max(sim)``.
    * ``'m'`` -> ``'mwmf'``: ``self.get_max_weight_match(sim)``.
    * ``'f'`` -> ``'fixed'``: starts as ``fwd * bwd`` and is then augmented
      greedily with pairs found in ``fwd_dict`` / ``bwd_dict``.

    Subword pairs whose matrix entry exceeds 1e-10 are projected back to
    word-index pairs.

    Args:
        src_sent: Sequence of source words (each must support ``.lower()``).
        trg_sent: Sequence of target words.
        langcodes: Forwarded unchanged to ``self.get_embed``.
        fwd_dict: Mapping ``lowercased source word -> {lowercased target word:
            score}``; used only for ``'fixed'``. Presumably translation
            probabilities — confirm with caller. ``None`` means "no entries".
        bwd_dict: Same as ``fwd_dict`` in the target -> source direction.
        debug: If true, print every dictionary pair added to ``'fixed'``.
        max_rel_dist: A dictionary pair (i, j) is considered only if
            ``|i/len(src) - j/len(trg)| < max_rel_dist`` (strict).
        min_prob: Dictionary candidates scoring below this are ignored.

    Returns:
        Dict mapping each method name ('inter', 'itermax', 'mwmf', 'fixed')
        that was requested to a sorted list of ``(src_word_idx,
        trg_word_idx)`` tuples.
    """
    l1_tokens = [self.tokenizer.tokenize(word) for word in src_sent]
    l2_tokens = [self.tokenizer.tokenize(word) for word in trg_sent]
    bpe_lists = [[bpe for w in sent for bpe in w] for sent in (l1_tokens, l2_tokens)]
    # Map every subword position back to the index of the word it came from.
    l1_b2w_map = [i for i, wlist in enumerate(l1_tokens) for _ in wlist]
    l2_b2w_map = [i for i, wlist in enumerate(l2_tokens) for _ in wlist]

    vectors = self.get_embed(bpe_lists, langcodes)
    # Rescale cosine similarity from [-1, 1] to [0, 1].
    sim = (cosine_similarity(vectors[0], vectors[1]) + 1.0) / 2.0
    sim = self.apply_distortion(sim, self.distortion)

    methods = self.matching_method
    all_mats = {}
    fwd, bwd = self.get_alignment_matrix(sim)
    if 'a' in methods:
        all_mats['inter'] = fwd * bwd
    if 'i' in methods:
        all_mats['itermax'] = self.iter_max(sim)
    if 'm' in methods:
        all_mats['mwmf'] = self.get_max_weight_match(sim)
    if 'f' in methods:
        # Same starting point as 'inter'; augmented with dictionary pairs below.
        all_mats['fixed'] = fwd * bwd

    n_src_bpe = vectors[0].shape[0]
    n_trg_bpe = vectors[1].shape[0]
    aligns = {
        key: {
            (l1_b2w_map[i], l2_b2w_map[j])
            for i in range(n_src_bpe)
            for j in range(n_trg_bpe)
            if mat[i, j] > 1e-10
        }
        for key, mat in all_mats.items()
    }

    if 'fixed' in aligns:
        fixed = aligns['fixed']
        # .get() below avoids KeyError on unknown words and avoids inserting
        # keys into a caller-supplied defaultdict; None means "no dictionary".
        fwd_dict = {} if fwd_dict is None else fwd_dict
        bwd_dict = {} if bwd_dict is None else bwd_dict
        src_aligned = {i for i, _ in fixed}
        trg_aligned = {j for _, j in fixed}
        src_lower = [w.lower() for w in src_sent]
        trg_lower = [w.lower() for w in trg_sent]
        n_src = len(src_sent)
        n_trg = len(trg_sent)

        def close(i, j):
            # Relative positions must be within the window (both ends strict).
            return -max_rel_dist < i / n_src - j / n_trg < max_rel_dist

        candidates = []
        # Forward direction: still-unaligned source words looked up in fwd_dict.
        for i, sw in enumerate(src_lower):
            if i in src_aligned:
                continue
            translations = fwd_dict.get(sw)
            if not translations:
                continue
            for j, tw in enumerate(trg_lower):
                if tw in translations and close(i, j):
                    candidates.append((sw, tw, i, j, translations[tw], 0))
        # Backward direction: still-unaligned target words looked up in bwd_dict.
        for j, tw in enumerate(trg_lower):
            if j in trg_aligned:
                continue
            translations = bwd_dict.get(tw)
            if not translations:
                continue
            for i, sw in enumerate(src_lower):
                if sw in translations and close(i, j):
                    candidates.append((sw, tw, i, j, translations[sw], 1))

        # Greedy pass, highest score first (stable for ties).
        candidates.sort(key=lambda c: -c[4])
        for sw, tw, i, j, val, direction in candidates:
            # Skip pairs where either word starts with a punctuation character.
            if regex.match(r'\p{P}', sw) or regex.match(r'\p{P}', tw):
                continue
            if val < min_prob:
                break  # sorted descending: every remaining score is lower
            if direction == 0:
                if i in src_aligned:
                    continue
                # Target may already be aligned if a neighbouring source word
                # points to it (allows one-to-many links).
                accept = (j not in trg_aligned
                          or (i - 1, j) in fixed
                          or (i + 1, j) in fixed)
            else:
                if j in trg_aligned:
                    continue
                accept = (i not in src_aligned
                          or (i, j + 1) in fixed
                          or (i, j - 1) in fixed)
            if accept:
                fixed.add((i, j))
                src_aligned.add(i)
                trg_aligned.add(j)
                if debug:
                    print(sw, tw, i, j, val, direction)

    return {key: sorted(pairs) for key, pairs in aligns.items()}