def main()

in codegen_sources/model/train.py


def main(params):
    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # build model
    print_memory(logger, "before build modules")
    if params.encoder_only:
        model = build_model(params, data["dico"])
    else:
        encoder, decoder = build_model(params, data["dico"])
    print_memory(logger, "before build classifier")

    if params.use_classifier:
        classifier = build_classifier(params)
    else:
        classifier = None

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params, classifier)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)
    print_memory(logger, "after building all models")

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            if isinstance(v, list):
                logger.info("%s -> %s" % (k, json.dumps(["%.2f" % el for el in v])))
            else:
                logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # training loop (language modeling, translation, and related objectives)
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" % trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < trainer.epoch_size:
            show_example = trainer.n_sentences == 0

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(
                    lang1, lang2, params.lambda_mlm, show_example=show_example
                )

            # denoising auto-encoder steps
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(
                    lang, lang, params.lambda_ae, show_example=show_example,
                )

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(
                    lang1, lang2, params.lambda_mt, show_example=show_example,
                )

            # machine translation steps using spans
            for lang1, lang2, span in shuf_order(params.mt_spans_steps, params):
                trainer.mt_step(
                    lang1,
                    lang2,
                    params.lambda_mt,
                    span=span,
                    show_example=show_example,
                )

            # deobfuscation steps
            for lang1, lang2 in shuf_order(params.do_steps):
                trainer.mt_step(
                    lang1,
                    lang2,
                    params.lambda_do,
                    deobfuscate=True,
                    deobfuscate_p=1 - params.obf_proba,
                    show_example=show_example,
                )

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(
                    lang1,
                    lang2,
                    lang3,
                    params.lambda_bt,
                    params.bt_sample_temperature,
                    show_example=show_example,
                )

            # Classification
            for lang1, lang2 in shuf_order(params.classif_steps, params):
                trainer.classif_step(
                    lang1,
                    lang2,
                    getattr(params, "lambda_classif_" + "_".join((lang1, lang2))),
                )

            # Self-Labelling
            for lang1, langs2 in shuf_order(params.st_steps):
                trainer.st_step(
                    lang1, langs2, params.lambda_st, show_example=show_example,
                )

            trainer.iter()

        logger.info("============ End of epoch %i ============" % trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            if isinstance(v, list):
                logger.info("%s -> %s" % (k, json.dumps(["%.2f" % el for el in v])))
            else:
                logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        if params.validation_metrics != "":
            trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
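

For context, here is a minimal sketch of how main is typically invoked from the bottom of train.py; the get_parser, check_data_params, and check_model_params helpers are assumed to be provided by the surrounding module and are not part of the excerpt above.

if __name__ == "__main__":
    # build the argument parser and parse command-line parameters
    # (get_parser is assumed to come from the surrounding module)
    parser = get_parser()
    params = parser.parse_args()

    # sanity-check data and model parameters before launching training
    # (check_data_params / check_model_params are assumed module helpers)
    check_data_params(params)
    check_model_params(params)

    # run the experiment
    main(params)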