def collect_bitext_stats()

in align/train.py [0:0]


def _read_lines_via_cat(path):
    """Return the lines of ``path`` after copying it with ``cat`` into a temp dir.

    The copy goes through the shell exactly as before, so whatever the shell
    accepts for ``path`` keeps working. The temporary directory and the file
    handle are released even if reading fails.

    Raises:
        RuntimeError: if the ``cat`` command exits with a non-zero status.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        copy_path = f'{tmpdir}/copy.txt'
        # NOTE(review): ``path`` is interpolated unquoted into a shell command;
        # paths containing spaces or shell metacharacters will misbehave.
        # Only pass trusted paths here.
        status = os.system(f'cat {path} > {copy_path}')
        if status != 0:
            # Without this check a failed read would produce empty statistics
            # that then get written to the on-disk cache and reused silently.
            raise RuntimeError(f'failed to read {path!r} (cat exit status {status})')
        with open(copy_path) as handle:
            return handle.readlines()


def _parse_alignment(line, is_reversed):
    """Parse one alignment line into a list of index tuples, or None if malformed.

    The line must be JSON with an ``'inter'`` entry holding the links. When
    ``is_reversed`` is true each link is reversed.
    """
    try:
        links = json.loads(line)['inter']
        if is_reversed:
            return [tuple(reversed(link)) for link in links]
        return [tuple(link) for link in links]
    except (ValueError, KeyError, TypeError):
        # ValueError covers invalid JSON; KeyError/TypeError cover a payload
        # without a usable 'inter' entry.
        return None


def _split_bitext_line(line, src_lang, trg_lang, is_reversed):
    """Split a ``src ||| trg`` line into ``(src_sent, trg_sent)``, or None if malformed.

    A line is malformed when splitting on ``|||`` does not give exactly two
    fields. The two sides are swapped when ``is_reversed`` is true, and a side
    whose language is ``'zh_CN'`` is passed through ``to_simplified``.
    """
    fields = line.strip().split('|||')
    if len(fields) != 2:
        return None
    src_sent, trg_sent = fields
    if is_reversed:
        src_sent, trg_sent = trg_sent, src_sent
    if src_lang == 'zh_CN':
        src_sent = to_simplified(src_sent)
    if trg_lang == 'zh_CN':
        trg_sent = to_simplified(trg_sent)
    return src_sent, trg_sent


def collect_bitext_stats(bitext_path, align_path, save_path, src_lang, trg_lang, is_reversed=False):
    """Compute (or load cached) word co-occurrence and frequency statistics.

    Results are cached under ``save_path`` as ``stats.pt`` and ``freqs.pt``
    via ``torch.save``; each file is loaded instead of recomputed when it
    already exists.

    Args:
        bitext_path: file with one ``source ||| target`` sentence pair per line.
        align_path: file with one JSON object per line, parallel to the
            bitext; its ``'inter'`` entry lists ``(source index, target index)``
            links over whitespace-separated tokens.
        save_path: directory (as a string) holding the cache files.
        src_lang: source language code; ``'zh_CN'`` triggers ``to_simplified``.
        trg_lang: target language code; ``'zh_CN'`` triggers ``to_simplified``.
        is_reversed: if true, swap the two sides of every sentence pair and
            reverse every alignment link.

    Returns:
        A tuple ``(coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg)``:

        * ``coocc[sw][tw]``: number of (source token, target token) position
          pairs within a sentence pair, over lowercased words.
        * ``semi_matched_coocc[sw][tw]``: the subset of those pairs that are
          alignment links.
        * ``matched_coocc[sw][tw]``: the subset of links whose source index
          and target index each occur in exactly one link of that sentence.
        * ``freq_src`` / ``freq_trg``: token counts for each side.

    Raises:
        ValueError: if the bitext and alignment files differ in line count.
        RuntimeError: if an input file cannot be read.
    """
    stats_path = save_path + '/stats.pt'
    freq_path = save_path + '/freqs.pt'
    bitext = None  # read lazily, shared by both passes when both caches are missing

    # NOTE(review): recent PyTorch releases default ``torch.load`` to
    # ``weights_only=True``, which may refuse to unpickle the defaultdict /
    # Counter objects saved below - confirm against the pinned torch version.
    if os.path.exists(stats_path):
        coocc, semi_matched_coocc, matched_coocc = torch.load(stats_path)
    else:
        coocc = collections.defaultdict(collections.Counter)
        semi_matched_coocc = collections.defaultdict(collections.Counter)
        matched_coocc = collections.defaultdict(collections.Counter)
        bitext = _read_lines_via_cat(bitext_path)
        aligns = _read_lines_via_cat(align_path)
        if len(bitext) != len(aligns):
            raise ValueError(
                f'bitext has {len(bitext)} lines but alignment file has {len(aligns)}'
            )
        for item, align_line in zip(tqdm(bitext), aligns):
            # Malformed lines are skipped on a best-effort basis.
            align = _parse_alignment(align_line, is_reversed)
            if align is None:
                continue
            pair = _split_bitext_line(item, src_lang, trg_lang, is_reversed)
            if pair is None:
                continue
            src_sent, trg_sent = pair
            src_words = src_sent.lower().split()
            trg_words = trg_sent.lower().split()
            # Counters are built from the list so duplicated links still count
            # more than once; the set is only used for O(1) membership tests.
            src_cnt = collections.Counter([link[0] for link in align])
            trg_cnt = collections.Counter([link[1] for link in align])
            linked = set(align)
            for x, sw in enumerate(src_words):
                for y, tw in enumerate(trg_words):
                    if (x, y) in linked:
                        semi_matched_coocc[sw][tw] += 1
                        # One-to-one link: neither endpoint is shared.
                        if src_cnt[x] == 1 and trg_cnt[y] == 1:
                            matched_coocc[sw][tw] += 1
                    coocc[sw][tw] += 1
        torch.save((coocc, semi_matched_coocc, matched_coocc), stats_path)

    if os.path.exists(freq_path):
        freq_src, freq_trg = torch.load(freq_path)
    else:
        freq_src = collections.Counter()
        freq_trg = collections.Counter()
        if bitext is None:
            bitext = _read_lines_via_cat(bitext_path)
        for item in tqdm(bitext):
            pair = _split_bitext_line(item, src_lang, trg_lang, is_reversed)
            if pair is None:
                continue
            src_sent, trg_sent = pair
            # NOTE(review): unlike the co-occurrence pass, tokens are not
            # lowercased here - confirm whether that asymmetry is intended.
            freq_src.update(src_sent.split())
            freq_trg.update(trg_sent.split())
        torch.save((freq_src, freq_trg), freq_path)
    return coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg