in code/src/trainer.py [0:0]
def __init__(self, encoder, decoder, discriminator, lm, data, params):
    """
    Set up the trainer: keep references to the models, build one optimizer
    per model, and reset the counters and statistics used while training.

    Args:
        encoder: encoder model. Its first parameter is asserted to have size
            (params.n_words, params.emb_dim); the assert is skipped under -O.
        decoder: decoder model; it always gets an optimizer built from
            params.dec_optimizer.
        discriminator: optional model (may be None); an optimizer is built
            from params.dis_optimizer only when it is given.
        lm: optional model (may be None); an optimizer is built from
            params.enc_optimizer only when it is given.
        data: forwarded to the parent constructor and stored on self.data.
        params: configuration object, stored on self.params. Side effects:
            params.dec_optimizer is overwritten in place when it equals
            'enc_optimizer', and parse_lambda_config(params, ...) is called
            for several keys at the end.
    """
    super().__init__(data, params)
    self.encoder, self.decoder = encoder, decoder
    self.discriminator, self.lm = discriminator, lm
    self.data, self.params = data, params

    # On-the-fly generation/training setup. It is only needed when
    # params.train_bt is set ("bt" presumably means back-translation; confirm).
    # It runs before any optimizer is created, as in the original ordering.
    if params.train_bt:
        self.otf_start_multiprocessing()

    # Encoder parameters. When the embedding is shared with the decoder, the
    # decoder optimizer owns it, so the first encoder parameter (the embedding
    # matrix checked by the assert) is left out here.
    enc_params = list(encoder.parameters())
    assert enc_params[0].size() == (params.n_words, params.emb_dim)
    if self.params.share_encdec_emb:
        enc_params = enc_params[1:]

    # 'enc_optimizer' used as the decoder optimizer value is an alias for
    # "reuse the encoder optimizer configuration".
    if params.dec_optimizer == 'enc_optimizer':
        params.dec_optimizer = params.enc_optimizer

    # One optimizer per model. The encoder may have no parameters of its own
    # left, and the discriminator and LM are optional.
    if enc_params:
        self.enc_optimizer = get_optimizer(enc_params, params.enc_optimizer)
    else:
        self.enc_optimizer = None
    self.dec_optimizer = get_optimizer(decoder.parameters(), params.dec_optimizer)
    if discriminator is None:
        self.dis_optimizer = None
    else:
        self.dis_optimizer = get_optimizer(discriminator.parameters(), params.dis_optimizer)
    if lm is None:
        self.lm_optimizer = None
    else:
        # the language model is configured with the encoder optimizer settings
        self.lm_optimizer = get_optimizer(lm.parameters(), params.enc_optimizer)

    # (model, optimizer) pairs keyed by short model name
    self.model_opt = dict(
        enc=(self.encoder, self.enc_optimizer),
        dec=(self.decoder, self.dec_optimizer),
        dis=(self.discriminator, self.dis_optimizer),
        lm=(self.lm, self.lm_optimizer),
    )

    # training progress / embedding-freezing flags
    self.epoch = 0
    self.n_total_iter = 0
    self.freeze_enc_emb = self.params.freeze_enc_emb
    self.freeze_dec_emb = self.params.freeze_dec_emb

    # Running statistics: counters plus one list of losses per objective.
    # Each list is a distinct object.
    self.n_iter = 0
    self.n_sentences = 0
    self.stats = {
        'processed_s': 0,
        'processed_w': 0,
        'xe_ae': [],
        'xe_bt': [],
        'xe_lm': [],
        'dis_costs': [],
    }
    self.last_time = time.time()
    # gen_time only exists when on-the-fly training is enabled
    if params.train_bt:
        self.gen_time = 0

    # data iterators, filled in lazily elsewhere (presumably; not visible here)
    self.iterators = {}

    # BPE subword setup
    self.init_bpe()

    # lambda coefficients / sampling temperature and their configurations
    for config_name in ('lambda_ae', 'lambda_bt', 'lambda_lm', 'lambda_dis', 'otf_temperature'):
        parse_lambda_config(params, config_name)