# NMT/src/trainer.py
def __init__(self, encoder, decoder, discriminator, lm, data, params):
    """
    Initialize the trainer.

    Args:
        encoder: shared encoder module (embeddings may be shared with the decoder).
        decoder: decoder module.
        discriminator: adversarial discriminator, or None to disable it.
        lm: language model used for LM-based losses, or None to disable them.
        data: dataset dict; ``data['para']`` maps (lang1, lang2) pairs to parallel corpora.
        params: experiment hyper-parameters (languages, directions, optimizer
            specs, sharing flags, stopping criterion, ...).
    """
    # The superclass handles multiprocess on-the-fly (OTF) back-translation
    # generation; one worker per process.
    super().__init__(device_ids=tuple(range(params.otf_num_processes)))
    self.encoder = encoder
    self.decoder = decoder
    self.discriminator = discriminator
    self.lm = lm
    self.data = data
    self.params = params

    # initialization for on-the-fly generation/training
    # (only needed when pivot directions, i.e. back-translation, are used)
    if len(params.pivo_directions) > 0:
        self.otf_start_multiprocessing()

    # define encoder parameters (the ones shared with the
    # decoder are optimized by the decoder optimizer)
    # NOTE(review): this relies on encoder.parameters() yielding the
    # per-language embedding matrices first, in language order — confirm
    # against the encoder's module registration order.
    enc_params = list(encoder.parameters())
    for i in range(params.n_langs):
        # with a shared language embedding there is a single embedding
        # matrix, so only index 0 needs checking
        if params.share_lang_emb and i > 0:
            break
        assert enc_params[i].size() == (params.n_words[i], params.emb_dim)
    if self.params.share_encdec_emb:
        # drop the embedding matrices from the encoder optimizer's
        # parameter list: they are owned by the decoder optimizer
        to_ignore = 1 if params.share_lang_emb else params.n_langs
        enc_params = enc_params[to_ignore:]

    # optimizers
    # 'enc_optimizer' is a sentinel meaning "reuse the encoder's optimizer
    # spec for the decoder"; resolve it before building the optimizers
    if params.dec_optimizer == 'enc_optimizer':
        params.dec_optimizer = params.enc_optimizer
    # enc_params can be empty when all encoder parameters are shared
    # with the decoder (share_encdec_emb with a single embedding)
    self.enc_optimizer = get_optimizer(enc_params, params.enc_optimizer) if len(enc_params) > 0 else None
    self.dec_optimizer = get_optimizer(decoder.parameters(), params.dec_optimizer)
    self.dis_optimizer = get_optimizer(discriminator.parameters(), params.dis_optimizer) if discriminator is not None else None
    # NOTE(review): the LM reuses the *encoder* optimizer spec — presumably
    # intentional; there is no dedicated lm_optimizer parameter here.
    self.lm_optimizer = get_optimizer(lm.parameters(), params.enc_optimizer) if lm is not None else None

    # models / optimizers, keyed by short name for generic save/update code
    self.model_opt = {
        'enc': (self.encoder, self.enc_optimizer),
        'dec': (self.decoder, self.dec_optimizer),
        'dis': (self.discriminator, self.dis_optimizer),
        'lm': (self.lm, self.lm_optimizer),
    }

    # define validation metrics / stopping criterion used for early stopping
    logger.info("Stopping criterion: %s" % params.stopping_criterion)
    if params.stopping_criterion == '':
        # no explicit criterion: track BLEU on every translation direction
        for lang1, lang2 in self.data['para'].keys():
            for data_type in ['valid', 'test']:
                self.VALIDATION_METRICS.append('bleu_%s_%s_%s' % (lang1, lang2, data_type))
        # pivot (back-translation) directions, excluding round trips
        # lang1 -> lang2 -> lang1
        for lang1, lang2, lang3 in self.params.pivo_directions:
            if lang1 == lang3:
                continue
            for data_type in ['valid', 'test']:
                self.VALIDATION_METRICS.append('bleu_%s_%s_%s_%s' % (lang1, lang2, lang3, data_type))
        self.stopping_criterion = None
        self.best_stopping_criterion = None
    else:
        # criterion format: "<metric_name>,<max_decrease_count>"
        split = params.stopping_criterion.split(',')
        assert len(split) == 2 and split[1].isdigit()
        self.decrease_counts_max = int(split[1])
        self.decrease_counts = 0
        self.stopping_criterion = split[0]
        self.best_stopping_criterion = -1e12
        # an explicit criterion replaces (not extends) the metric list
        assert len(self.VALIDATION_METRICS) == 0
        self.VALIDATION_METRICS.append(self.stopping_criterion)

    # training variables
    self.best_metrics = {metric: -1e12 for metric in self.VALIDATION_METRICS}
    self.epoch = 0
    self.n_total_iter = 0
    self.freeze_enc_emb = self.params.freeze_enc_emb
    self.freeze_dec_emb = self.params.freeze_dec_emb

    # training statistics
    self.n_iter = 0
    self.n_sentences = 0
    self.stats = {
        'dis_costs': [],            # discriminator losses
        'processed_s': 0,           # sentences processed since last log
        'processed_w': 0,           # words processed since last log
    }
    # cross-entropy losses for denoising auto-encoding (lang -> lang)
    for lang in params.mono_directions:
        self.stats['xe_costs_%s_%s' % (lang, lang)] = []
    # cross-entropy losses for supervised translation directions
    for lang1, lang2 in params.para_directions:
        self.stats['xe_costs_%s_%s' % (lang1, lang2)] = []
    # cross-entropy losses for back-parallel directions
    for lang1, lang2 in params.back_directions:
        self.stats['xe_costs_bt_%s_%s' % (lang1, lang2)] = []
    # cross-entropy losses for pivot (on-the-fly back-translation) directions
    for lang1, lang2, lang3 in params.pivo_directions:
        self.stats['xe_costs_%s_%s_%s' % (lang1, lang2, lang3)] = []
    # per-language LM losses (lme/lmd/lmer) and encoder output norms
    for lang in params.langs:
        self.stats['lme_costs_%s' % lang] = []
        self.stats['lmd_costs_%s' % lang] = []
        self.stats['lmer_costs_%s' % lang] = []
        self.stats['enc_norms_%s' % lang] = []
    self.last_time = time.time()
    if len(params.pivo_directions) > 0:
        # cumulative time spent generating OTF back-translations
        self.gen_time = 0

    # data iterators, created lazily per direction
    self.iterators = {}

    # initialize BPE subwords
    self.init_bpe()

    # initialize lambda coefficients (loss weights) and their schedules
    parse_lambda_config(params, 'lambda_xe_mono')
    parse_lambda_config(params, 'lambda_xe_para')
    parse_lambda_config(params, 'lambda_xe_back')
    parse_lambda_config(params, 'lambda_xe_otfd')
    parse_lambda_config(params, 'lambda_xe_otfa')
    parse_lambda_config(params, 'lambda_dis')
    parse_lambda_config(params, 'lambda_lm')