in align/train.py [0:0]
def _read_concatenated_lines(path_pattern):
    """Return every line of the file(s) named by ``path_pattern``, in order.

    Replaces the former ``os.system('cat {path} > tmp')`` round-trip, which
    silently produced an empty file when the path was missing (and those
    empty stats were then cached) and broke on paths containing spaces or
    shell metacharacters. A path that exists is read as-is; otherwise it is
    treated as a glob pattern (as the shell used to expand it) and the
    matches are read in sorted order.

    Raises:
        FileNotFoundError: if nothing exists at, or matches, ``path_pattern``.
    """
    import glob

    expanded = os.path.expanduser(path_pattern)
    if os.path.exists(expanded):
        paths = [expanded]
    else:
        # Fall back to the literal path so open() raises a clear error.
        paths = sorted(glob.glob(expanded)) or [expanded]
    lines = []
    for path in paths:
        with open(path, encoding='utf-8') as f:
            lines.extend(f.readlines())
    return lines


def _split_bitext_line(line, is_reversed):
    """Split a ``src ||| trg`` line into ``(src_sent, trg_sent)``.

    The two sides are swapped when ``is_reversed`` is true. Raises
    ValueError unless the line contains exactly one ``|||`` separator.
    """
    src_sent, trg_sent = line.strip().split('|||')
    if is_reversed:
        src_sent, trg_sent = trg_sent, src_sent
    return src_sent, trg_sent


def _maybe_simplify(sent, lang):
    """Pass ``sent`` through ``to_simplified`` when ``lang`` is 'zh_CN'."""
    return to_simplified(sent) if lang == 'zh_CN' else sent


def collect_bitext_stats(bitext_path, align_path, save_path, src_lang, trg_lang, is_reversed=False):
    """Collect word co-occurrence and frequency statistics from an aligned bitext.

    Each line of ``bitext_path`` is ``src ||| trg``. The matching line of
    ``align_path`` is a JSON object whose ``'inter'`` entry lists
    ``(src_index, trg_index)`` pairs over the whitespace-split tokens.
    Lines that fail to parse are skipped.

    Results are cached with ``torch.save`` in ``save_path/stats.pt`` and
    ``save_path/freqs.pt``, and are reloaded from there when present.
    NOTE: the cache ignores every other argument. Reusing ``save_path`` with
    a different corpus, language pair or ``is_reversed`` returns stale stats.

    Args:
        bitext_path: path (or glob pattern) of the parallel text.
        align_path: path (or glob pattern) of the alignments, line-aligned
            with ``bitext_path``.
        save_path: directory for the cache files (not created here).
        src_lang, trg_lang: language codes; a 'zh_CN' side is passed through
            ``to_simplified`` before tokenization.
        is_reversed: swap the source/target sides of both the bitext and the
            alignment pairs.

    Returns:
        ``(coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg)``.
        - ``coocc[s][t]`` counts every (source position, target position)
          pair of lowercased words within a sentence pair.
        - ``semi_matched_coocc`` counts only the position pairs present in
          the alignment.
        - ``matched_coocc`` counts only aligned pairs whose source and target
          positions each occur exactly once in the alignment.
        - ``freq_src`` / ``freq_trg`` count whitespace tokens per side over
          every line that splits into two sides, even if its alignment line
          is malformed. These tokens are NOT lowercased, unlike the
          co-occurrence keys.

    Raises:
        FileNotFoundError: if an input that must be read does not exist.
        ValueError: if the bitext and alignment inputs differ in line count.
    """
    stats_path = os.path.join(save_path, 'stats.pt')
    freq_path = os.path.join(save_path, 'freqs.pt')

    if os.path.exists(stats_path):
        coocc, semi_matched_coocc, matched_coocc = torch.load(stats_path)
    else:
        coocc = collections.defaultdict(collections.Counter)
        semi_matched_coocc = collections.defaultdict(collections.Counter)
        matched_coocc = collections.defaultdict(collections.Counter)
        bitext = _read_concatenated_lines(bitext_path)
        aligns = _read_concatenated_lines(align_path)
        if len(bitext) != len(aligns):
            raise ValueError(
                f'{bitext_path} has {len(bitext)} lines but '
                f'{align_path} has {len(aligns)}'
            )
        for item, align_line in tqdm(zip(bitext, aligns), total=len(bitext)):
            try:
                src_sent, trg_sent = _split_bitext_line(item, is_reversed)
                align = [
                    tuple(reversed(pair)) if is_reversed else tuple(pair)
                    for pair in json.loads(align_line)['inter']
                ]
            except (ValueError, KeyError, TypeError):
                # Malformed bitext or alignment line (bad separator, bad JSON,
                # missing/ill-typed 'inter'): skip it.
                continue
            src_words = _maybe_simplify(src_sent, src_lang).lower().split()
            trg_words = _maybe_simplify(trg_sent, trg_lang).lower().split()
            # How many links touch each position; duplicate links count too.
            src_cnt = collections.Counter(pair[0] for pair in align)
            trg_cnt = collections.Counter(pair[1] for pair in align)
            # Set gives O(1) membership instead of a list scan per word pair.
            aligned = set(align)
            for x, sw in enumerate(src_words):
                for y, tw in enumerate(trg_words):
                    if (x, y) in aligned:
                        semi_matched_coocc[sw][tw] += 1
                        # Only one-to-one links count as fully matched.
                        if src_cnt[x] == 1 and trg_cnt[y] == 1:
                            matched_coocc[sw][tw] += 1
                    coocc[sw][tw] += 1
        torch.save((coocc, semi_matched_coocc, matched_coocc), stats_path)

    if os.path.exists(freq_path):
        freq_src, freq_trg = torch.load(freq_path)
    else:
        freq_src = collections.Counter()
        freq_trg = collections.Counter()
        for item in tqdm(_read_concatenated_lines(bitext_path)):
            try:
                src_sent, trg_sent = _split_bitext_line(item, is_reversed)
            except ValueError:
                continue
            # NOTE(review): unlike the co-occurrence keys above, these tokens
            # are not lowercased -- confirm callers expect the original case.
            freq_src.update(_maybe_simplify(src_sent, src_lang).split())
            freq_trg.update(_maybe_simplify(trg_sent, trg_lang).split())
        torch.save((freq_src, freq_trg), freq_path)

    return coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg