def __init__()

in src/training.py [0:0]


    def __init__(self, ae, lat_dis, ptc_dis, clf_dis, data, params):
        """
        Trainer initialization.

        Stores the data, parameters and networks, builds the optimizers,
        optionally reloads pretrained weights, and resets training statistics.

        Args:
            ae: autoencoder. Always gets an optimizer built from
                `params.ae_optimizer` (stored as `self.ae_optimizer`).
            lat_dis, ptc_dis, clf_dis: latent / patch / classifier
                discriminators. Each one is logged and gets an optimizer built
                from `params.dis_optimizer` only when the matching
                `params.n_lat_dis` / `params.n_ptc_dis` / `params.n_clf_dis`
                is truthy. Otherwise the corresponding `*_dis_optimizer`
                attribute is never set.
            data: stored unchanged on `self.data`.
            params: experiment options, stored on `self.params`. This same
                object is mutated: `params.n_total_iter` is reset to 0.
        """
        def log_module(module, message):
            # Log the module, then its total number of parameter elements
            # substituted into `message` (a format string with one `%i`).
            logger.info(module)
            n_params = sum(p.nelement() for p in module.parameters())
            logger.info(message % n_params)

        # keep references to the inputs
        self.data = data
        self.params = params
        self.ae = ae
        self.lat_dis = lat_dis
        self.ptc_dis = ptc_dis
        self.clf_dis = clf_dis

        # The autoencoder optimizer is built before logging the autoencoder;
        # each discriminator is logged first, then gets its optimizer.
        self.ae_optimizer = get_optimizer(ae, params.ae_optimizer)
        log_module(ae, '%i parameters in the autoencoder. ')

        if params.n_lat_dis:
            log_module(lat_dis, '%i parameters in the latent discriminator. ')
            self.lat_dis_optimizer = get_optimizer(lat_dis, params.dis_optimizer)

        if params.n_ptc_dis:
            log_module(ptc_dis, '%i parameters in the patch discriminator. ')
            self.ptc_dis_optimizer = get_optimizer(ptc_dis, params.dis_optimizer)

        if params.n_clf_dis:
            log_module(clf_dis, '%i parameters in the classifier discriminator. ')
            self.clf_dis_optimizer = get_optimizer(clf_dis, params.dis_optimizer)

        # Optional weight reloading. Each entry holds a network, the name of
        # the `params` option with its checkpoint path, and the attribute
        # names passed to `reload_model`. Presumably reload_model checks these
        # against the checkpoint (confirm in reload_model). Reloading is not
        # gated on the `n_*` flags above.
        reload_specs = (
            (ae, 'ae_reload',
             ['img_sz', 'img_fm', 'init_fm', 'n_layers', 'n_skip', 'attr', 'n_attr']),
            (lat_dis, 'lat_dis_reload',
             ['enc_dim', 'attr', 'n_attr']),
            (ptc_dis, 'ptc_dis_reload',
             ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'n_patch_dis_layers']),
            (clf_dis, 'clf_dis_reload',
             ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'hid_dim', 'attr', 'n_attr']),
        )
        for module, option, attributes in reload_specs:
            # Read each option lazily, in order, just before its reload.
            checkpoint = getattr(params, option)
            if checkpoint:  # a falsy path skips reloading this network
                reload_model(module, checkpoint, attributes)

        # Empty lists for collecting training costs (filled elsewhere).
        self.stats = {
            key: [] for key in ('rec_costs', 'lat_dis_costs',
                                'ptc_dis_costs', 'clf_dis_costs')
        }

        # Starting values for the best reconstruction loss / best accuracy
        # seen so far (presumably improved upon during evaluation).
        self.best_loss = 1e12
        self.best_accu = -1e12
        self.params.n_total_iter = 0