def __init__()

in model/transformer_gan.py [0:0]


    def __init__(self, cfg, vocab):
        """Build the generator and discriminator(s) of the Transformer GAN.

        Components created:
          * ``self.generator``: a ``MemTransformerLM`` over ``len(vocab)`` tokens.
          * ``self.dis_D`` (and ``self.P0 = None``): only when either
            ``cfg.DISCRIMINATOR.CNN.loss_type`` or
            ``cfg.DISCRIMINATOR.BERT.loss_type`` contains ``'ppo'``.
          * ``self.discriminator``: chosen by ``cfg.DISCRIMINATOR.type``
            (``"bert"`` or ``"cnn"``); ``None`` for any other value.

        Args:
            cfg: Experiment config. Fields read here include ``cfg.PPO.dis_D_type``,
                ``cfg.PPO.dis_D_num_rep``, ``cfg.DISCRIMINATOR.type``,
                ``cfg.DISCRIMINATOR.tgt_len`` and the ``BERT`` / ``CNN``
                sub-sections of ``cfg.DISCRIMINATOR``.
            vocab: Vocabulary; must support ``len()`` and expose ``vec_len``.

        Raises:
            ValueError: If a PPO loss is configured but ``cfg.PPO.dis_D_type``
                is neither ``"bert"`` nor ``"cnn"``.
        """
        super(TransformerGAN, self).__init__()

        self.ntokens = len(vocab)
        self.generator = MemTransformerLM(
            cfg,
            self.ntokens,
            vocab.vec_len,
        )

        def create_dis_D():
            """Build the auxiliary discriminator ``dis_D`` selected by ``cfg.PPO.dis_D_type``."""
            dis_D_type = cfg.PPO.dis_D_type
            if dis_D_type == "bert":
                # NOTE(review): unlike the main BERT discriminator below, this call
                # does not pass cfg.DISCRIMINATOR.BERT.random_weights; presumably
                # create_bert_model has a default for it — confirm this is intended.
                dis_D = self.create_bert_model(
                    cfg.DISCRIMINATOR.BERT.model_path,
                    cfg.DISCRIMINATOR.BERT.loss_type,
                    cfg.DISCRIMINATOR.BERT.model_type,
                )
                dis_D.unfreeze_idx = self.calculate_unfreeze_idx(cfg)
                return dis_D
            if dis_D_type == "cnn":
                return RelGAN_D(
                    cfg.DISCRIMINATOR.CNN.embed_dim,
                    cfg.DISCRIMINATOR.tgt_len,
                    cfg.PPO.dis_D_num_rep,  # Has to be 1 if used with BERT
                    self.ntokens,
                    1,
                    cfg=cfg,
                )
            # Any other value would leave nothing to return; fail with a clear
            # message instead of an UnboundLocalError.
            raise ValueError(
                f"Unsupported cfg.PPO.dis_D_type {dis_D_type!r}; expected 'bert' or 'cnn'"
            )

        # The auxiliary discriminator is only built when a configured loss type
        # mentions 'ppo'; self.dis_D and self.P0 exist only on this path.
        if 'ppo' in cfg.DISCRIMINATOR.CNN.loss_type or 'ppo' in cfg.DISCRIMINATOR.BERT.loss_type:
            self.dis_D = create_dis_D()
            self.P0 = None

        # Create discriminator
        if cfg.DISCRIMINATOR.type == "bert":
            # Can change d_embed
            self.discriminator = self.create_bert_model(
                cfg.DISCRIMINATOR.BERT.model_path, cfg.DISCRIMINATOR.BERT.loss_type, cfg.DISCRIMINATOR.BERT.model_type,
                cfg.DISCRIMINATOR.BERT.random_weights
            )
            self.discriminator.unfreeze_idx = self.calculate_unfreeze_idx(cfg)

        elif cfg.DISCRIMINATOR.type == "cnn":
            self.discriminator = RelGAN_D(
                cfg.DISCRIMINATOR.CNN.embed_dim,
                cfg.DISCRIMINATOR.tgt_len,
                cfg.DISCRIMINATOR.CNN.num_rep,
                self.ntokens,
                1,
                cfg=cfg,
            )

        else:
            # Unknown / unset type: run without a main discriminator.
            self.discriminator = None

        self.cfg = cfg
        self.temperature = 1
        self.vocab = vocab
        self.vec_len = vocab.vec_len