# Excerpt from floresv1/scripts/translate.py: the main() entry point.

def main():
    """Split a monolingual corpus into chunks, translate each chunk with a
    fairseq checkpoint, and aggregate the results into paired
    source/hypothesis files (``generated.src`` / ``generated.hypo``).

    Side effects: creates directories and files under ``--dest`` (or this
    script's directory when ``--dest`` is not given) and shells out to
    ``split`` and ``cat``/``grep``/``cut``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', '-d', required=True, help='Path to file to translate')
    parser.add_argument('--model', '-m', required=True, help='Model checkpoint')
    parser.add_argument('--lenpen', default=1.2, type=float, help='Length penalty')
    parser.add_argument('--beam', default=5, type=int, help='Beam size')
    parser.add_argument('--max-len-a', type=float, default=0, help='max-len-a parameter when back-translating')
    parser.add_argument('--max-len-b', type=int, default=200, help='max-len-b parameter when back-translating')
    parser.add_argument('--cpu', type=int, default=4, help='Number of CPU for interactive.py')
    parser.add_argument('--cuda-visible-device-ids', '-gids', default=None, nargs='*', help='List of cuda visible devices ids, comma separated')
    parser.add_argument('--dest', help='Output path for the intermediate and translated file')
    parser.add_argument('--max-tokens', type=int, default=12000, help='max tokens')
    parser.add_argument('--buffer-size', type=int, default=10000, help='Buffer size')
    parser.add_argument('--chunks', type=int, default=100)
    parser.add_argument('--source-lang', type=str, default=None, help='Source language. Will inference from the model if not set')
    parser.add_argument('--target-lang', type=str, default=None, help='Target language. Will inference from the model if not set')
    parser.add_argument('--databin', type=str, default=None, help='Parallel databin. Will combine with the back-translated databin')
    parser.add_argument('--sbatch-args', default='', help='Extra SBATCH arguments')

    parser.add_argument('--backend', type=str, default='local', choices=['local', 'slurm'])
    args = parser.parse_args()

    # Default to every visible GPU when no explicit device list is given.
    args.cuda_visible_device_ids = args.cuda_visible_device_ids or list(range(torch.cuda.device_count()))

    # Only the stored training args are read from the checkpoint; load onto
    # CPU so this also works on machines without the original GPU setup.
    chkpnt = torch.load(args.model, map_location='cpu')
    model_args = chkpnt['args']
    # Fall back to the languages/databin the model was trained with.
    # (The `or` fallback already handles None, so no outer guard is needed.)
    args.source_lang = args.source_lang or model_args.source_lang
    args.target_lang = args.target_lang or model_args.target_lang
    args.databin = args.databin or model_args.data

    root_dir = os.path.dirname(os.path.realpath(__file__))
    # Resolve the destination ONCE. Previously the aggregation paths below
    # used args.dest directly, which produced 'None/generated.src' when
    # --dest was omitted while translation_dir fell back to root_dir.
    dest_dir = args.dest or root_dir
    translation_dir = os.path.join(dest_dir, 'translations', f'{args.source_lang}-{args.target_lang}')

    tempdir = os.path.join(translation_dir, 'splits')
    os.makedirs(tempdir, exist_ok=True)
    split_files = glob(f'{tempdir}/mono_data*')

    if len(split_files) != args.chunks:
        if split_files:
            print("number of split files are not the same as chunks. removing files and re-split")
            for leftover in os.listdir(tempdir):
                os.remove(os.path.join(tempdir, leftover))
        print("splitting files ...")
        # `split -n r/N` distributes input lines round-robin into N chunks.
        check_call(f'split -n "r/{args.chunks}" -a3 -d {args.data} {tempdir}/mono_data', shell=True)
        split_files = glob(f'{tempdir}/mono_data*')
    else:
        print("has the same number of splitted file and the specified chunks, skip splitting file")

    translated_files = []
    files_to_translate = []
    for file in split_files:
        # skip the translation job if it's finished
        output_file = get_output_file(translation_dir, file)
        translated_files.append(output_file)
        if check_finished(output_file):
            print(f"{output_file} is translated")
            continue
        files_to_translate.append(file)

    print(f"{len(files_to_translate)} files to translate")

    translate_files(args, translation_dir, files_to_translate)

    # aggregate translated files (skipped when a consistent, non-empty pair
    # of output files already exists)
    generated_src = f'{dest_dir}/generated.src'
    generated_tgt = f'{dest_dir}/generated.hypo'
    if count_line(generated_src) != count_line(generated_tgt) or count_line(generated_src) <= 0:
        print(f"aggregating translated {len(translated_files)} files")
        with TempFile() as fout:
            files = " ".join(translated_files)
            check_call(f"cat {files}", shell=True, stdout=fout)
            # strip the fairseq 'S-'/'H-' prefixes and pair source lines
            # (field 2 of S records) with hypotheses (field 3 of H records)
            check_call(f'cat {fout.name} | grep "^S" | cut -f2 > {generated_src}', shell=True)
            check_call(f'cat {fout.name} | grep "^H" | cut -f3 > {generated_tgt}', shell=True)
    assert count_line(generated_src) == count_line(generated_tgt)
    print(f"output generated files to {generated_src}, {generated_tgt}")