in XLM/train.py [0:0]
import json

# module paths below follow the XLM repository layout; they are implied
# by this excerpt rather than shown in it
from xlm.slurm import init_signal_handler, init_distributed_mode
from xlm.data.loader import load_data
from xlm.utils import initialize_exp, set_sampling_probs, shuf_order
from xlm.model import build_model
from xlm.trainer import SingleTrainer, EncDecTrainer
from xlm.evaluation.evaluator import SingleEvaluator, EncDecEvaluator


def main(params):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()
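    # (the handler in xlm/slurm.py checkpoints and requeues the SLURM job
    # when it receives the pre-emption / time-limit signal)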

    # load data
    data = load_data(params)
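    # (`data` is the dictionary produced by the XLM loader: the shared
    # vocabulary under 'dico' plus the monolingual / parallel datasets)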

    # build model
    if params.encoder_only:
        model = build_model(params, data['dico'])
    else:
        encoder, decoder = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)
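    # (languages are weighted by corpus size, smoothed by the
    # --lg_sampling_factor coefficient, so that low-resource languages
    # are sampled more often; a note based on the XLM utilities)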

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" % trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # denoising auto-encoder steps
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)
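            # (the AE objective reuses mt_step with lang1 == lang2: XLM's
            # trainer corrupts the source sentence and trains the model
            # to reconstruct the original)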

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3,
                                params.lambda_bt, params.bt_sample_temperature)
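            # (bt_steps holds triples such as ('en', 'fr', 'en'): translate
            # lang1 into lang2 with the current model, then train on the
            # generated pair to reconstruct lang3, which equals lang1)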

            trainer.iter()
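            # (iter() advances the step counter and emits periodic
            # training statistics in XLM's Trainer)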
logger.info("============ End of epoch %i ============" %
trainer.epoch)
# evaluate perplexity
scores = evaluator.run_all_evals(trainer)
# print / JSON log
for k, v in scores.items():
logger.info("%s -> %.6f" % (k, v))
if params.is_master:
logger.info("__log__:%s" % json.dumps(scores))
# end of epoch
if params.validation_metrics != '':
trainer.save_best_model(scores)
trainer.save_periodic()
trainer.end_epoch(scores)
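
# For context, a minimal sketch of how this entry point is typically
# invoked; get_parser, check_data_params, and check_model_params are
# names implied by the surrounding XLM repository, not shown in this
# excerpt:
if __name__ == '__main__':

    # parse command-line arguments into `params`
    parser = get_parser()
    params = parser.parse_args()

    # sanity-check data / model parameters before training starts
    check_data_params(params)
    check_model_params(params)

    # run the training loop defined above
    main(params)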