in NMT/main.py [0:0]
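# NOTE: excerpt only; the full module is assumed to provide the imports used below
# (json, time, and the project helpers: initialize_exp, load_data, build_mt_model,
# check_all_data_params, check_mt_model_params, TrainerMT, EvaluatorMT).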
def main(params):
    # check parameters
    assert params.exp_name
    check_all_data_params(params)
    check_mt_model_params(params)

    # initialize experiment / load data / build model
    logger = initialize_exp(params)
    data = load_data(params)
    encoder, decoder, discriminator, lm = build_mt_model(params, data)

    # initialize trainer / reload checkpoint / initialize evaluator
    trainer = TrainerMT(encoder, decoder, discriminator, lm, data, params)
    trainer.reload_checkpoint()
    trainer.test_sharing()  # check parameters sharing
    evaluator = EvaluatorMT(trainer, data, params)

    # evaluation mode
    if params.eval_only:
        evaluator.run_all_evals(0)
        exit()

    # language model pretraining
    if params.lm_before > 0:
        logger.info("Pretraining language model for %i iterations ..." % params.lm_before)
        trainer.n_sentences = 0
        for _ in range(params.lm_before):
            for lang in params.langs:
                trainer.lm_step(lang)
            trainer.iter()

    # define epoch size
    if params.epoch_size == -1:
        params.epoch_size = params.n_para
    assert params.epoch_size > 0

    # start training
    for _ in range(trainer.epoch, params.max_epoch):

        logger.info("====================== Starting epoch %i ... ======================" % trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < params.epoch_size:

            # discriminator training
            for _ in range(params.n_dis):
                trainer.discriminator_step()

            # language model training
            if params.lambda_lm > 0:
                for _ in range(params.lm_after):
                    for lang in params.langs:
                        trainer.lm_step(lang)

            # MT training (parallel data)
            if params.lambda_xe_para > 0:
                for lang1, lang2 in params.para_directions:
                    trainer.enc_dec_step(lang1, lang2, params.lambda_xe_para)

            # MT training (back-parallel data)
            if params.lambda_xe_back > 0:
                for lang1, lang2 in params.back_directions:
                    trainer.enc_dec_step(lang1, lang2, params.lambda_xe_back, back=True)

            # autoencoder training (monolingual data)
            if params.lambda_xe_mono > 0:
                for lang in params.mono_directions:
                    trainer.enc_dec_step(lang, lang, params.lambda_xe_mono)

            # AE - MT training (on the fly back-translation)
            if params.lambda_xe_otfd > 0 or params.lambda_xe_otfa > 0:

                # start on-the-fly batch generations
                if not getattr(params, 'started_otf_batch_gen', False):
                    otf_iterator = trainer.otf_bt_gen_async()
                    params.started_otf_batch_gen = True

                # update model parameters on subprocesses
                if trainer.n_iter % params.otf_sync_params_every == 0:
                    trainer.otf_sync_params()

                # get training batch from CPU
                before_gen = time.time()
                batches = next(otf_iterator)
                trainer.gen_time += time.time() - before_gen
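
                # NOTE: each generated batch carries a (lang1, lang2, lang3) direction tuple; from
                # the dispatch below, lang2 == lang3 is the autoencoding-style round trip (weighted
                # by lambda_xe_otfa), while lang1 == lang3 and the fully distinct 3-language case
                # are the back-translation objectives (weighted by lambda_xe_otfd).
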
                # training
                for batch in batches:
                    lang1, lang2, lang3 = batch['lang1'], batch['lang2'], batch['lang3']
                    # 2-lang back-translation - autoencoding
                    if lang1 != lang2 == lang3:
                        trainer.otf_bt(batch, params.lambda_xe_otfa, params.otf_backprop_temperature)
                    # 2-lang back-translation - parallel data
                    elif lang1 == lang3 != lang2:
                        trainer.otf_bt(batch, params.lambda_xe_otfd, params.otf_backprop_temperature)
                    # 3-lang back-translation - parallel data
                    elif lang1 != lang2 and lang2 != lang3 and lang1 != lang3:
                        trainer.otf_bt(batch, params.lambda_xe_otfd, params.otf_backprop_temperature)

            trainer.iter()

        # end of epoch
        logger.info("====================== End of epoch %i ======================" % trainer.epoch)

        # evaluate discriminator / perplexity / BLEU
        scores = evaluator.run_all_evals(trainer.epoch)

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))

        # save best / save periodic / end epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
        trainer.test_sharing()
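
# Entry-point sketch (assumption, not part of this excerpt): the full script is expected
# to parse the command-line flags into `params` and hand them to main(), roughly as below.
# The parser helper name is illustrative only.
#
#     if __name__ == '__main__':
#         params = get_parser().parse_args()
#         main(params)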