in src/training.py [0:0]
def __init__(self, ae, lat_dis, ptc_dis, clf_dis, data, params):
    """
    Set up the trainer.

    Stores the models, the data and the parameters on the instance, builds
    one optimizer for the autoencoder and one for each enabled
    discriminator, calls `reload_model` for every model whose
    `params.*_reload` value is truthy, and resets the training statistics.

    A discriminator is only logged and given an optimizer when the matching
    `params.n_lat_dis` / `params.n_ptc_dis` / `params.n_clf_dis` value is
    truthy; otherwise the corresponding `*_optimizer` attribute is not set.
    """
    def _log_and_optimize(module, message):
        # Log the module and its parameter count, then build its optimizer
        # from the shared discriminator optimizer setting.
        logger.info(module)
        logger.info(message % sum(p.nelement() for p in module.parameters()))
        return get_optimizer(module, params.dis_optimizer)

    # modules
    self.ae, self.lat_dis = ae, lat_dis
    self.ptc_dis, self.clf_dis = ptc_dis, clf_dis

    # data / parameters
    self.data, self.params = data, params

    # autoencoder optimizer (built before the model summary is logged)
    self.ae_optimizer = get_optimizer(ae, params.ae_optimizer)
    logger.info(ae)
    n_ae_params = sum(p.nelement() for p in ae.parameters())
    logger.info('%i parameters in the autoencoder. ' % n_ae_params)

    # discriminator optimizers, one per enabled discriminator
    if params.n_lat_dis:
        self.lat_dis_optimizer = _log_and_optimize(
            lat_dis, '%i parameters in the latent discriminator. ')
    if params.n_ptc_dis:
        self.ptc_dis_optimizer = _log_and_optimize(
            ptc_dis, '%i parameters in the patch discriminator. ')
    if params.n_clf_dis:
        self.clf_dis_optimizer = _log_and_optimize(
            clf_dis, '%i parameters in the classifier discriminator. ')

    # reload pretrained models
    # NOTE(review): the name lists are passed straight to reload_model;
    # presumably they are the attributes it checks for compatibility --
    # confirm against reload_model.
    ae_path = params.ae_reload
    if ae_path:
        reload_model(
            ae, ae_path,
            ['img_sz', 'img_fm', 'init_fm', 'n_layers', 'n_skip', 'attr', 'n_attr'],
        )
    lat_dis_path = params.lat_dis_reload
    if lat_dis_path:
        reload_model(
            lat_dis, lat_dis_path,
            ['enc_dim', 'attr', 'n_attr'],
        )
    ptc_dis_path = params.ptc_dis_reload
    if ptc_dis_path:
        reload_model(
            ptc_dis, ptc_dis_path,
            ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'n_patch_dis_layers'],
        )
    clf_dis_path = params.clf_dis_reload
    if clf_dis_path:
        reload_model(
            clf_dis, clf_dis_path,
            ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'hid_dim', 'attr', 'n_attr'],
        )

    # training statistics: one empty cost history per tracked loss
    self.stats = {
        'rec_costs': [],
        'lat_dis_costs': [],
        'ptc_dis_costs': [],
        'clf_dis_costs': [],
    }

    # best reconstruction loss / best accuracy seen so far (sentinel values)
    self.best_loss = 1e12
    self.best_accu = -1e12

    # the iteration counter lives on the shared params object
    self.params.n_total_iter = 0