def __init__()

in NMT/src/trainer.py [0:0]


    def __init__(self, encoder, decoder, discriminator, lm, data, params):
        """
        Set up the trainer: models, optimizers, metrics and statistics.

        Steps run in this order: parent initializer, model / data / config
        attributes, on-the-fly multiprocessing (only when pivot directions
        are configured), optimizers, validation metrics and stopping
        criterion, training counters and statistics, BPE initialization,
        and finally the lambda coefficient configurations.

        Args:
            encoder: model exposing `parameters()`.
            decoder: model exposing `parameters()`.
            discriminator: model exposing `parameters()`, or None, in which
                case no discriminator optimizer is created.
            lm: model exposing `parameters()`, or None, in which case no
                language model optimizer is created.
            data: container indexed by 'para'; the keys of `data['para']`
                are unpacked as (lang1, lang2) pairs.
            params: configuration object. `params.dec_optimizer` is
                overwritten in place with `params.enc_optimizer` when it
                is the literal string 'enc_optimizer'.

        Raises:
            AssertionError: if a checked encoder parameter does not have
                size (n_words[i], emb_dim), if a non-empty stopping
                criterion is not of the form "<metric>,<digits>", or if
                `VALIDATION_METRICS` is not empty when a stopping criterion
                is given.
        """
        # the parent receives one device id per on-the-fly process
        super().__init__(device_ids=tuple(range(params.otf_num_processes)))

        # keep references to the models, the data and the configuration
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        self.lm = lm
        self.data = data
        self.params = params

        # on-the-fly generation / training needs its worker processes started
        if len(params.pivo_directions) > 0:
            self.otf_start_multiprocessing()

        # encoder parameters: the first `n_emb_tables` entries are checked to
        # have size (n_words, emb_dim); there is a single one when the language
        # embeddings are shared. The ones shared with the decoder are optimized
        # by the decoder optimizer, so they are dropped from the encoder list.
        enc_params = list(encoder.parameters())
        n_emb_tables = 1 if params.share_lang_emb else params.n_langs
        for lang_id in range(min(n_emb_tables, params.n_langs)):
            expected_size = (params.n_words[lang_id], params.emb_dim)
            assert enc_params[lang_id].size() == expected_size
        if self.params.share_encdec_emb:
            enc_params = enc_params[n_emb_tables:]

        # optimizers ('enc_optimizer' means "same configuration as the encoder")
        if params.dec_optimizer == 'enc_optimizer':
            params.dec_optimizer = params.enc_optimizer

        if len(enc_params) > 0:
            self.enc_optimizer = get_optimizer(enc_params, params.enc_optimizer)
        else:
            self.enc_optimizer = None

        self.dec_optimizer = get_optimizer(decoder.parameters(), params.dec_optimizer)

        if discriminator is not None:
            self.dis_optimizer = get_optimizer(discriminator.parameters(), params.dis_optimizer)
        else:
            self.dis_optimizer = None

        # the language model reuses the encoder optimizer configuration
        if lm is not None:
            self.lm_optimizer = get_optimizer(lm.parameters(), params.enc_optimizer)
        else:
            self.lm_optimizer = None

        # name -> (model, optimizer)
        self.model_opt = dict(
            enc=(self.encoder, self.enc_optimizer),
            dec=(self.decoder, self.dec_optimizer),
            dis=(self.discriminator, self.dis_optimizer),
            lm=(self.lm, self.lm_optimizer),
        )

        # validation metrics / stopping criterion used for early stopping
        logger.info("Stopping criterion: %s" % params.stopping_criterion)
        metrics = self.VALIDATION_METRICS
        if params.stopping_criterion != '':
            # expected format: "<metric name>,<max number of decreases>"
            parts = params.stopping_criterion.split(',')
            assert len(parts) == 2 and parts[1].isdigit()
            criterion_name, max_decreases = parts
            self.decrease_counts_max = int(max_decreases)
            self.decrease_counts = 0
            self.stopping_criterion = criterion_name
            self.best_stopping_criterion = -1e12
            # the criterion must be the only validation metric
            assert len(metrics) == 0
            metrics.append(self.stopping_criterion)
        else:
            # no criterion: track BLEU on valid / test for every parallel pair
            # and for every pivot direction whose first and last languages differ
            for src, tgt in self.data['para'].keys():
                for split_name in ('valid', 'test'):
                    metrics.append('bleu_%s_%s_%s' % (src, tgt, split_name))
            for src, mid, tgt in params.pivo_directions:
                if src != tgt:
                    for split_name in ('valid', 'test'):
                        metrics.append('bleu_%s_%s_%s_%s' % (src, mid, tgt, split_name))
            self.stopping_criterion = None
            self.best_stopping_criterion = None

        # training variables
        self.best_metrics = dict.fromkeys(self.VALIDATION_METRICS, -1e12)
        self.epoch = 0
        self.n_total_iter = 0
        self.freeze_enc_emb = self.params.freeze_enc_emb
        self.freeze_dec_emb = self.params.freeze_dec_emb

        # training statistics
        self.n_iter = 0
        self.n_sentences = 0
        stats = {
            'dis_costs': [],
            'processed_s': 0,
            'processed_w': 0,
        }
        # one list of cross-entropy costs per configured training direction
        for lang in params.mono_directions:
            stats['xe_costs_%s_%s' % (lang, lang)] = []
        for src, tgt in params.para_directions:
            stats['xe_costs_%s_%s' % (src, tgt)] = []
        for src, tgt in params.back_directions:
            stats['xe_costs_bt_%s_%s' % (src, tgt)] = []
        for src, mid, tgt in params.pivo_directions:
            stats['xe_costs_%s_%s_%s' % (src, mid, tgt)] = []
        # per-language statistics
        per_lang_keys = ('lme_costs_%s', 'lmd_costs_%s', 'lmer_costs_%s', 'enc_norms_%s')
        for lang in params.langs:
            for key_template in per_lang_keys:
                stats[key_template % lang] = []
        self.stats = stats
        self.last_time = time.time()
        # generation time is only defined when pivot directions are configured
        if len(params.pivo_directions) > 0:
            self.gen_time = 0

        # data iterators
        self.iterators = {}

        # initialize BPE subwords
        self.init_bpe()

        # initialize lambda coefficients and their configurations
        lambda_names = (
            'lambda_xe_mono',
            'lambda_xe_para',
            'lambda_xe_back',
            'lambda_xe_otfd',
            'lambda_xe_otfa',
            'lambda_dis',
            'lambda_lm',
        )
        for lambda_name in lambda_names:
            parse_lambda_config(params, lambda_name)