in floresv1/scripts/train.py [0:0]
def main():
    """Run the iterative back-translation training pipeline.

    For each (name, config) entry in the pipeline config: train a model on
    the current databin, report test-set BLEU, and — except on the final
    iteration — back-translate monolingual data to build the databin that
    seeds the next iteration's training run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', required=True, help='pipeline config')
    parser.add_argument('--databin', '-d', required=True, help='initial databin')
    args = parser.parse_args()

    configs = read_config(args.config)
    workdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../experiments')

    initial_databin = args.databin
    for i, (name, config) in enumerate(configs):
        src = config['src']
        tgt = config['tgt']
        direction = f"{src}-{tgt}"
        print(f"Start {name} iteration, {direction}")
        iter_workdir = os.path.join(workdir, name, direction)

        # Train on the databin produced by the previous iteration
        # (or the initial one on the first pass).
        model_dir = os.path.join(iter_workdir, 'model')
        train(src, tgt, config['train'], model_dir, initial_databin)
        checkpoint_path = os.path.join(model_dir, 'checkpoint_best.pt')

        # Evaluate: if a previous run already left a final "BLEU" line in
        # eval.txt, just reprint it instead of re-running evaluation.
        lenpen = config['translate']['lenpen']
        eval_output = os.path.join(model_dir, 'eval.txt')
        if check_last_line(eval_output, "BLEU"):
            # List form with shell=False: no shell-string interpolation of the path.
            print(check_output(["tail", "-n", "1", eval_output]).decode('utf-8').strip())
        else:
            print(eval_bleu(
                src, tgt,
                'test', lenpen,
                args.databin, checkpoint_path,
                eval_output,
            ))

        # The last iteration has no successor, so skip back-translation.
        if i == len(configs) - 1:
            break

        # Translate monolingual data with the freshly trained model.
        translate_output = os.path.join(iter_workdir, 'synthetic')
        translate(
            src, tgt, checkpoint_path, lenpen, translate_output,
            config['translate']['mono'], config['translate']['max_token'],
        )

        # Build the back-translated databin (direction reversed: tgt->src
        # synthetic pairs) consumed by the next iteration's training.
        databin_folder = os.path.join(translate_output, 'bt')
        initial_databin = build_bt_databin(
            tgt, src,
            os.path.join(translate_output, 'generated'), args.databin, databin_folder,
        )