def vizualize_translated_files()

in XLM/src/utils.py [0:0]


def _viz_wrap(text, width=50):
    """Insert a newline after every ``width``-th character of ``text``.

    Used when a line cannot be detokenized, so that long single-line token
    sequences still fit inside a ``pr`` column.
    """
    return ''.join(
        char + '\n' if (pos + 1) % width == 0 else char
        for pos, char in enumerate(text))


def _viz_get_detokenizer(lang):
    """Return ``code_tokenizer.detokenize_<prefix>`` for ``lang``, or None.

    ``<prefix>`` is the part of ``lang`` before the first underscore
    (e.g. ``'xx_yy'`` -> ``'xx'``). Any failure during the lookup yields None,
    so callers fall back to plain wrapping.
    """
    try:
        return getattr(code_tokenizer, f"detokenize_{lang.split('_')[0]}")
    except Exception:
        return None


def _viz_render(text, detokenize):
    """Return ``detokenize(text)``, or ``text`` hard-wrapped if that fails.

    Falls back to ``_viz_wrap`` when ``detokenize`` is None or raises.
    """
    if detokenize is None:
        return _viz_wrap(text)
    try:
        return detokenize(text)
    except Exception:
        # Best effort: one line the detokenizer cannot handle should not
        # abort the whole visualization.
        return _viz_wrap(text)


def vizualize_translated_files(lang1, lang2, src_file, hyp_file, ids, ref_file=None, out_file=None):
    """Write a side-by-side text view of sources, references, hypotheses and outputs.

    Four columns (source | reference | hypothesis | output) are written to
    temporary files and merged with ``pr -m``. The merged file's path is
    ``hyp_file[0]`` with every ``beam<digit>`` substring removed and its
    suffix replaced by ``.vizualize.txt``. Each example starts with its id,
    and each beam gets its own block of rows ended by a ``-`` line.
    Examples are paired positionally; the shortest input decides how many
    examples are rendered.

    Source lines are detokenized with the ``lang1`` detokenizer. Reference
    and hypothesis lines use the ``lang2`` one. When no detokenizer is
    available, or it raises, the text is hard-wrapped every 50 characters
    instead. Output lines are always hard-wrapped.

    Args:
        lang1: Source language. The part before the first ``_`` selects the
            ``code_tokenizer.detokenize_*`` function for source lines.
        lang2: Target language. It selects the detokenizer for reference and
            hypothesis lines.
        src_file: Path of the source file, one example per line.
        hyp_file: Sequence of hypothesis file paths, one per beam. Line k of
            each file belongs to example k.
        ids: Path of a file holding one example identifier per line.
        ref_file: Optional path of the reference file. References are left
            blank when None.
        out_file: Optional sequence of output file paths laid out like
            ``hyp_file``. Outputs are left blank when None.

    Requires the ``pr`` utility on PATH. If it cannot be started, a warning
    is issued and the merged file is left empty.
    """
    import contextlib
    import tempfile
    import warnings

    # Strip the 'beam<digit>' tag so every beam maps to a single merged file.
    merged_viz = str(Path(re.sub(r'beam\d', '', hyp_file[0])).with_suffix('.vizualize.txt'))

    with open(ids, 'r', encoding='utf-8') as f:
        id_lines = f.readlines()
    with open(src_file, encoding='utf-8') as f:
        src_lines = f.readlines()  # test_size
    if ref_file is not None:
        with open(ref_file, encoding='utf-8') as f:
            ref_lines = f.readlines()  # test_size
    else:
        ref_lines = ['' for _ in src_lines]

    # There is one file per beam; zip(*) regroups them into one tuple per
    # example. Taking the beam count from the file list (rather than the
    # first tuple) avoids an IndexError when the files are empty.
    beam_size = len(hyp_file)
    hyp_lines = list(zip(*[read_file_lines(path) for path in hyp_file]))  # test_size * beam_size
    if out_file is not None:
        out_lines = list(zip(*[read_file_lines(path) for path in out_file]))  # test_size * beam_size
    else:
        out_lines = [['' for _ in range(beam_size)] for _ in src_lines]

    # Resolve the detokenizers once instead of once per line.
    src_detok = _viz_get_detokenizer(lang1)
    tgt_detok = _viz_get_detokenizer(lang2)

    # The per-column files are scratch data (``pr -t`` prints no file names),
    # so a private temp dir avoids name clashes and is always cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Left-to-right column order of the merged view.
        column_paths = [str(Path(tmp_dir) / f'{name}.vizualize.txt')
                        for name in ('src', 'ref', 'hyp', 'out')]

        with contextlib.ExitStack() as stack:
            src_f, ref_f, hyp_f, out_f = [
                stack.enter_context(open(path, 'w', encoding='utf-8'))
                for path in column_paths]
            columns = (src_f, ref_f, hyp_f, out_f)

            src_f.write('========================SOURCE============================\n')
            hyp_f.write('=========================HYPO=============================\n')
            ref_f.write('==========================REF=============================\n')
            out_f.write('==========================OUT=============================\n')

            for src, hyps, ref, outs, sample_id in zip(src_lines, hyp_lines, ref_lines, out_lines, id_lines):
                for col in columns:
                    col.write('=========================================================\n')
                    col.write(f'{sample_id}')
                    col.write('--\n')

                src = _viz_render(src, src_detok)
                src_f.write(src)
                ref = _viz_render(ref, tgt_detok)
                ref_f.write(ref)
                src_height = src.count('\n') + 1
                ref_height = ref.count('\n') + 1

                for beam_idx in range(beam_size):
                    hyp = _viz_render(hyps[beam_idx], tgt_detok)
                    hyp_f.write(hyp)
                    out = _viz_wrap(outs[beam_idx])
                    out_f.write(out)
                    hyp_height = hyp.count('\n') + 1
                    out_height = out.count('\n') + 1

                    # Pad so that every column advances by the same number of
                    # lines for this beam, keeping the merged view aligned.
                    # src/ref text is written once per example, so beam 0 pads
                    # to the tallest of the four texts. Later beams only add
                    # blank lines to src/ref; the '-\n' written below supplies
                    # the final line.
                    if beam_idx == 0:
                        height = max(src_height, hyp_height, ref_height, out_height)
                        src_f.write('\n' * (height - src_height))
                        ref_f.write('\n' * (height - ref_height))
                    else:
                        height = max(hyp_height, out_height)
                        src_f.write('\n' * (height - 1))
                        ref_f.write('\n' * (height - 1))
                    hyp_f.write('\n' * (height - hyp_height))
                    out_f.write('\n' * (height - out_height))

                    for col in columns:
                        col.write('-\n')

                for col in columns:
                    col.write('--\n\n')

        # Merge the columns side by side. An argument list (no shell) keeps
        # paths containing spaces or shell metacharacters intact.
        with open(merged_viz, 'wb') as merged:
            try:
                subprocess.run(['pr', '-w', '250', '-m', '-t', *column_paths],
                               stdout=merged, stderr=subprocess.DEVNULL, check=False)
            except OSError as err:
                # Best effort: if `pr` cannot be started, leave the merged file
                # empty and warn rather than aborting the caller.
                warnings.warn(f'could not merge visualization columns into {merged_viz}: {err}')