in align/models.py [0:0]
def align_sents(self, sent_pairs, train_file=None, *, giza_path=None, **kwargs):
    """Word-align sentence pairs with the backend named by ``self.aligner_type``.

    Args:
        sent_pairs: Iterable of ``(source, target)`` sentence strings whose
            tokens are separated by whitespace. It is materialised into a
            list for the ``fastalign`` and ``giza++`` backends, so a
            generator is accepted.
        train_file: Path to extra parallel text, one ``src ||| trg`` pair per
            line. Optional for ``fastalign`` (appended to the bitext that is
            aligned), required for ``giza++``, unused by the other backends.
        giza_path: Location of the GIZA++ executable. Defaults to the path
            that used to be hard-coded in this method.
        **kwargs: Forwarded to ``self.aligner.get_word_aligns`` for the
            ``simalign`` / ``criss-align`` backends; ignored otherwise.

    Returns:
        A list of alignment strings. For ``simalign`` / ``criss-align`` and
        ``giza++`` each string is a space-separated sequence of ``i-j``
        links; for ``fastalign`` each string is a stripped line of the
        ``atools`` output. An unrecognised ``self.aligner_type`` yields an
        empty list.

    Raises:
        ValueError: ``giza++`` was requested without ``train_file``.
        FileNotFoundError: ``train_file`` was given to ``fastalign`` but does
            not exist, or an external executable could not be started.
        subprocess.CalledProcessError: ``fast_align`` or ``atools`` exited
            with a non-zero status.
        RuntimeError: GIZA++ did not produce a ``*tst.A3*`` output file.
    """
    # Imported here so the block does not depend on edits to the file header.
    import shutil
    import subprocess

    aligns = []

    if self.aligner_type in ('simalign', 'criss-align'):
        for src, trg in tqdm(sent_pairs):
            align_info = self.aligner.get_word_aligns(
                src.strip().split(), trg.strip().split(), **kwargs
            )
            # Keep only the links that every returned method agrees on.
            agreed = None
            for key in align_info:
                links = set(align_info[key])
                agreed = links if agreed is None else agreed & links
            if agreed is None:
                # The aligner returned no methods at all: emit an empty
                # alignment instead of crashing in sorted(None).
                agreed = set()
            aligns.append(
                ' '.join('-'.join(str(x) for x in link) for link in sorted(agreed))
            )
        return aligns

    if self.aligner_type == 'fastalign':
        if train_file is not None and not os.path.exists(train_file):
            raise FileNotFoundError(f'train_file does not exist: {train_file}')
        sent_pairs = list(sent_pairs)
        if not sent_pairs:
            return aligns

        def run_to_file(cmd, out_path):
            # Argument lists (no shell) keep odd characters in paths harmless;
            # check=True surfaces tool failures instead of returning garbage.
            with open(out_path, 'wb') as out:
                subprocess.run(cmd, stdout=out, check=True)

        # The context manager removes the directory even if a tool fails.
        with tempfile.TemporaryDirectory(prefix='fast-align') as work_dir:
            bitext = os.path.join(work_dir, 'bitext.txt')
            fwd = os.path.join(work_dir, 'fwd.align')
            bwd = os.path.join(work_dir, 'bwd.align')
            final = os.path.join(work_dir, 'final.align')

            # The pairs to align go first so that their alignments are the
            # first len(sent_pairs) lines of the output.
            with open(bitext, 'w') as fout:
                for ss, ts in sent_pairs:
                    fout.write(ss + ' ||| ' + ts + '\n')
            if train_file is not None:
                # Byte-for-byte append, like the previous `cat >>`.
                with open(train_file, 'rb') as fin, open(bitext, 'ab') as fout:
                    shutil.copyfileobj(fin, fout)

            run_to_file(['fast_align', '-d', '-o', '-v', '-i', bitext], fwd)
            run_to_file(['fast_align', '-d', '-o', '-v', '-r', '-i', bitext], bwd)
            run_to_file(
                ['atools', '-i', fwd, '-j', bwd, '-c', 'grow-diag-final-and'],
                final,
            )
            with open(final) as fin:
                aligns = [line.strip() for line in fin]
        return aligns[:len(sent_pairs)]

    if self.aligner_type == 'giza++':
        if train_file is None:
            raise ValueError("train_file is required when aligner_type is 'giza++'")
        sent_pairs = list(sent_pairs)
        if not sent_pairs:
            return aligns
        if giza_path is None:
            giza_path = '/private/home/fhs/codebase/lexind/fairseq/2-word-align-final/giza-pp/GIZA++-v2/GIZA++'

        def read_train_pairs():
            # Streams the training file; it is read twice (counting, then
            # encoding), so it is never held in memory as a whole.
            with open(train_file) as fin:
                for line in fin:
                    ss, ts = line.lower().split('|||')
                    yield ss.strip().split(), ts.strip().split()

        test_pairs = [
            (ss.lower().strip().split(), ts.lower().strip().split())
            for ss, ts in sent_pairs
        ]

        # Word frequencies over the training and test text together.
        src_counts = collections.Counter()
        trg_counts = collections.Counter()
        for pairs in (read_train_pairs(), test_pairs):
            for src_words, trg_words in pairs:
                src_counts.update(src_words)
                trg_counts.update(trg_words)

        def write_vocab(path, counts):
            # One "id word count" line per word; ids follow sorted word
            # order and start at 1.
            word_ids = {}
            with open(path, 'w') as fout:
                for idx, word in enumerate(sorted(counts), start=1):
                    print(idx, word, counts[word], file=fout)
                    word_ids[word] = idx
            return word_ids

        def write_corpus(path, pairs, src_ids, trg_ids):
            # Three lines per pair: the literal 1, source ids, target ids.
            with open(path, 'w') as fout:
                for src_words, trg_words in pairs:
                    print(1, file=fout)
                    print(' '.join(str(src_ids[w]) for w in src_words), file=fout)
                    print(' '.join(str(trg_ids[w]) for w in trg_words), file=fout)

        with tempfile.TemporaryDirectory(prefix='giza++') as work_dir:
            s_vcb = os.path.join(work_dir, 's.vcb')
            t_vcb = os.path.join(work_dir, 't.vcb')
            train_snt = os.path.join(work_dir, 'bitext.train')
            test_snt = os.path.join(work_dir, 'bitext.test')

            src_ids = write_vocab(s_vcb, src_counts)
            trg_ids = write_vocab(t_vcb, trg_counts)
            write_corpus(train_snt, read_train_pairs(), src_ids, trg_ids)
            write_corpus(test_snt, test_pairs, src_ids, trg_ids)

            # Run inside the scratch directory via cwd= rather than
            # os.chdir(), which would leave the whole process sitting in a
            # directory that is deleted below. The exit status is not
            # checked; success is judged by the presence of the output file.
            subprocess.run(
                [giza_path, '-S', s_vcb, '-T', t_vcb, '-C', train_snt, '-tc', test_snt],
                cwd=work_dir,
            )

            outputs = glob(os.path.join(work_dir, '*tst.A3*'))
            if not outputs:
                raise RuntimeError(
                    f'GIZA++ produced no *tst.A3* file in {work_dir}; '
                    f'check that {giza_path} ran successfully'
                )

            # read giza++ results
            with open(outputs[0]) as fin:
                for i, line in enumerate(fin):
                    # Only every third line (index 2, 5, ...) is parsed.
                    if i % 3 != 2:
                        continue
                    links = []
                    in_link_set = False   # between '({' and '})'
                    after_null = False    # current link set belongs to NULL
                    src_idx = 0
                    for item in line.strip().split():
                        if item == '({':
                            in_link_set = True
                        elif item == '})':
                            in_link_set = False
                        elif in_link_set:
                            # Links attached to the NULL token are dropped.
                            if not after_null:
                                links.append(f'{src_idx}-{int(item)}')
                        elif item != 'NULL':
                            # NOTE(review): src_idx is incremented before
                            # use, so source positions here start at 1; the
                            # other backends appear to be 0-based - confirm
                            # what callers expect before changing.
                            src_idx += 1
                            after_null = False
                        else:
                            after_null = True
                    aligns.append(' '.join(links))
        return aligns

    # Unrecognised aligner type: nothing to align with.
    return aligns