def main()

in NMT/main.py [0:0]


def main(params):
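    # NOTE: only main() is shown in this listing; `json` and `time` (used below)
    # and the helpers check_all_data_params, load_data, build_mt_model, TrainerMT
    # and EvaluatorMT are imported at module level in NMT/main.py.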
    # check parameters
    assert params.exp_name
    check_all_data_params(params)
    check_mt_model_params(params)

    # initialize experiment / load data / build model
    logger = initialize_exp(params)
    data = load_data(params)
    encoder, decoder, discriminator, lm = build_mt_model(params, data)

    # initialize trainer / reload checkpoint / initialize evaluator
    trainer = TrainerMT(encoder, decoder, discriminator, lm, data, params)
    trainer.reload_checkpoint()
    trainer.test_sharing()  # check parameter sharing
    evaluator = EvaluatorMT(trainer, data, params)

    # evaluation mode
    if params.eval_only:
        evaluator.run_all_evals(0)
        exit()

    # language model pretraining
    if params.lm_before > 0:
        logger.info("Pretraining language model for %i iterations ..." % params.lm_before)
        trainer.n_sentences = 0
        for _ in range(params.lm_before):
            for lang in params.langs:
                trainer.lm_step(lang)
            trainer.iter()

    # define epoch size
    if params.epoch_size == -1:
        params.epoch_size = params.n_para
    assert params.epoch_size > 0

    # start training
    for _ in range(trainer.epoch, params.max_epoch):

        logger.info("====================== Starting epoch %i ... ======================" % trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < params.epoch_size:
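            # one epoch corresponds to params.epoch_size training sentences,
            # as counted by the trainer's n_sentences counter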

            # discriminator training
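            # (the discriminator is trained to predict the language of the
            # encoder's latent representations, as in the adversarial
            # unsupervised-MT setup; n_dis steps per training iteration)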
            for _ in range(params.n_dis):
                trainer.discriminator_step()

            # language model training
            if params.lambda_lm > 0:
                for _ in range(params.lm_after):
                    for lang in params.langs:
                        trainer.lm_step(lang)

            # MT training (parallel data)
            if params.lambda_xe_para > 0:
                for lang1, lang2 in params.para_directions:
                    trainer.enc_dec_step(lang1, lang2, params.lambda_xe_para)

            # MT training (back-parallel data)
            if params.lambda_xe_back > 0:
                for lang1, lang2 in params.back_directions:
                    trainer.enc_dec_step(lang1, lang2, params.lambda_xe_back, back=True)

            # autoencoder training (monolingual data)
            if params.lambda_xe_mono > 0:
                for lang in params.mono_directions:
                    trainer.enc_dec_step(lang, lang, params.lambda_xe_mono)

            # AE - MT training (on-the-fly back-translation)
            if params.lambda_xe_otfd > 0 or params.lambda_xe_otfa > 0:
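                # translations are generated asynchronously by CPU worker
                # processes holding a copy of the model, so batch generation
                # overlaps with training; the copies are re-synced periodically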

                # start on-the-fly batch generation
                if not getattr(params, 'started_otf_batch_gen', False):
                    otf_iterator = trainer.otf_bt_gen_async()
                    params.started_otf_batch_gen = True

                # periodically sync model parameters to the generation subprocesses
                if trainer.n_iter % params.otf_sync_params_every == 0:
                    trainer.otf_sync_params()

                # fetch the next generated batch (produced on CPU) and track generation time
                before_gen = time.time()
                batches = next(otf_iterator)
                trainer.gen_time += time.time() - before_gen

                # training
                for batch in batches:
                    lang1, lang2, lang3 = batch['lang1'], batch['lang2'], batch['lang3']
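                    # lang1 -> lang2 was generated on the fly; the model is then
                    # trained on the lang2 -> lang3 direction (lang2 == lang3:
                    # autoencoding on generated data; lang1 == lang3: standard
                    # back-translation)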
                    # 2-lang back-translation - autoencoding
                    if lang1 != lang2 == lang3:
                        trainer.otf_bt(batch, params.lambda_xe_otfa, params.otf_backprop_temperature)
                    # 2-lang back-translation - parallel data
                    elif lang1 == lang3 != lang2:
                        trainer.otf_bt(batch, params.lambda_xe_otfd, params.otf_backprop_temperature)
                    # 3-lang back-translation - parallel data
                    elif lang1 != lang2 and lang2 != lang3 and lang1 != lang3:
                        trainer.otf_bt(batch, params.lambda_xe_otfd, params.otf_backprop_temperature)

            trainer.iter()

        # end of epoch
        logger.info("====================== End of epoch %i ======================" % trainer.epoch)

        # evaluate discriminator / perplexity / BLEU
        scores = evaluator.run_all_evals(trainer.epoch)

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))

        # save best model / save periodic checkpoint / end epoch / re-check parameter sharing
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
        trainer.test_sharing()