in floresv1/scripts/translate.py [0:0]
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', '-d', required=True, help='Path to file to translate')
parser.add_argument('--model', '-m', required=True, help='Model checkpoint')
parser.add_argument('--lenpen', default=1.2, type=float, help='Length penalty')
parser.add_argument('--beam', default=5, type=int, help='Beam size')
parser.add_argument('--max-len-a', type=float, default=0, help='max-len-a parameter when back-translating')
parser.add_argument('--max-len-b', type=int, default=200, help='max-len-b parameter when back-translating')
    parser.add_argument('--cpu', type=int, default=4, help='Number of CPUs for interactive.py')
    parser.add_argument('--cuda-visible-device-ids', '-gids', default=None, nargs='*', help='List of CUDA visible device IDs')
parser.add_argument('--dest', help='Output path for the intermediate and translated file')
parser.add_argument('--max-tokens', type=int, default=12000, help='max tokens')
parser.add_argument('--buffer-size', type=int, default=10000, help='Buffer size')
parser.add_argument('--chunks', type=int, default=100)
    parser.add_argument('--source-lang', type=str, default=None, help='Source language. Inferred from the model if not set')
    parser.add_argument('--target-lang', type=str, default=None, help='Target language. Inferred from the model if not set')
parser.add_argument('--databin', type=str, default=None, help='Parallel databin. Will combine with the back-translated databin')
parser.add_argument('--sbatch-args', default='', help='Extra SBATCH arguments')
parser.add_argument('--backend', type=str, default='local', choices=['local', 'slurm'])
args = parser.parse_args()
args.cuda_visible_device_ids = args.cuda_visible_device_ids or list(range(torch.cuda.device_count()))
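    # load the checkpoint only to read its training-time args (source/target languages, databin path)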
chkpnt = torch.load(args.model)
model_args = chkpnt['args']
    args.source_lang = args.source_lang or model_args.source_lang
    args.target_lang = args.target_lang or model_args.target_lang
    args.databin = args.databin or model_args.data
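    # output layout: <dest>/translations/<src>-<tgt>/ holds translations, with the chunked input under splits/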
root_dir = os.path.dirname(os.path.realpath(__file__))
    dest_dir = args.dest or root_dir
    translation_dir = os.path.join(dest_dir, 'translations', f'{args.source_lang}-{args.target_lang}')
tempdir = os.path.join(translation_dir, 'splits')
os.makedirs(tempdir, exist_ok=True)
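    # the monolingual input is split into --chunks pieces so each piece can be translated independently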
split_files = glob(f'{tempdir}/mono_data*')
if len(split_files) != args.chunks:
if len(split_files) != 0:
print("number of split files are not the same as chunks. removing files and re-split")
[os.remove(os.path.join(tempdir, f)) for f in os.listdir(tempdir)]
print("splitting files ...")
check_call(f'split -n "r/{args.chunks}" -a3 -d {args.data} {tempdir}/mono_data', shell=True)
split_files = glob(f'{tempdir}/mono_data*')
else:
print("has the same number of splitted file and the specified chunks, skip splitting file")
translated_files = []
files_to_translate = []
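    # collect the expected output files and figure out which splits still need to be translated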
for file in split_files:
# skip the translation job if it's finished
output_file = get_output_file(translation_dir, file)
translated_files.append(output_file)
if check_finished(output_file):
print(f"{output_file} is translated")
continue
files_to_translate.append(file)
print(f"{len(files_to_translate)} files to translate")
translate_files(args, translation_dir, files_to_translate)
# aggregate translated files
    generated_src = f'{dest_dir}/generated.src'
    generated_tgt = f'{dest_dir}/generated.hypo'
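    # re-aggregate when the paired outputs are empty or their line counts disagree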
if count_line(generated_src) != count_line(generated_tgt) or count_line(generated_src) <= 0:
print(f"aggregating translated {len(translated_files)} files")
with TempFile() as fout:
files = " ".join(translated_files)
check_call(f"cat {files}", shell=True, stdout=fout)
            # fairseq generation output: S-lines carry the source in field 2, H-lines carry the score then the hypothesis in field 3
            check_call(f'grep "^S" {fout.name} | cut -f2 > {generated_src}', shell=True)
            check_call(f'grep "^H" {fout.name} | cut -f3 > {generated_tgt}', shell=True)
assert count_line(generated_src) == count_line(generated_tgt)
print(f"output generated files to {generated_src}, {generated_tgt}")