in model/transformer_gan.py [0:0]
def __init__(self, cfg, vocab):
    """Build the Transformer-GAN: a generator plus optional discriminators.

    Args:
        cfg: Config node. Fields read here include ``cfg.PPO.dis_D_type``,
            ``cfg.PPO.dis_D_num_rep``, ``cfg.DISCRIMINATOR.type``,
            ``cfg.DISCRIMINATOR.tgt_len`` and the ``cfg.DISCRIMINATOR.BERT.*`` /
            ``cfg.DISCRIMINATOR.CNN.*`` sub-configs.
        vocab: Vocabulary object; ``len(vocab)`` gives the token count and
            ``vocab.vec_len`` is forwarded to the generator.

    Raises:
        ValueError: if a PPO loss is configured for either discriminator but
            ``cfg.PPO.dis_D_type`` is neither ``"bert"`` nor ``"cnn"``.
    """
    super(TransformerGAN, self).__init__()
    self.ntokens = len(vocab)
    self.generator = MemTransformerLM(
        cfg,
        self.ntokens,
        vocab.vec_len,
    )

    def create_dis_D():
        """Build the auxiliary discriminator ``dis_D`` selected by ``cfg.PPO.dis_D_type``."""
        dis_D_type = cfg.PPO.dis_D_type
        if dis_D_type == "bert":
            # NOTE(review): unlike self.discriminator below, random_weights is
            # not forwarded here -- confirm this is intentional.
            dis_D = self.create_bert_model(
                cfg.DISCRIMINATOR.BERT.model_path,
                cfg.DISCRIMINATOR.BERT.loss_type,
                cfg.DISCRIMINATOR.BERT.model_type,
            )
            dis_D.unfreeze_idx = self.calculate_unfreeze_idx(cfg)
            return dis_D
        if dis_D_type == "cnn":
            return RelGAN_D(
                cfg.DISCRIMINATOR.CNN.embed_dim,
                cfg.DISCRIMINATOR.tgt_len,
                cfg.PPO.dis_D_num_rep,  # Has to be 1 if used with BERT
                self.ntokens,
                1,
                cfg=cfg,
            )
        # Without this, an unknown type fell through to an UnboundLocalError.
        raise ValueError(
            "Unsupported cfg.PPO.dis_D_type {!r}; expected 'bert' or 'cnn'".format(dis_D_type)
        )

    # Always define P0 so later reads never hit an AttributeError; it is only
    # reset to None here (presumably PPO state filled in elsewhere -- TODO confirm).
    self.P0 = None
    # dis_D is only built when a PPO-style loss is selected for either discriminator.
    if 'ppo' in cfg.DISCRIMINATOR.CNN.loss_type or 'ppo' in cfg.DISCRIMINATOR.BERT.loss_type:
        self.dis_D = create_dis_D()

    # Create discriminator
    if cfg.DISCRIMINATOR.type == "bert":
        # Can change d_embed
        self.discriminator = self.create_bert_model(
            cfg.DISCRIMINATOR.BERT.model_path,
            cfg.DISCRIMINATOR.BERT.loss_type,
            cfg.DISCRIMINATOR.BERT.model_type,
            cfg.DISCRIMINATOR.BERT.random_weights,
        )
        self.discriminator.unfreeze_idx = self.calculate_unfreeze_idx(cfg)
    elif cfg.DISCRIMINATOR.type == "cnn":
        self.discriminator = RelGAN_D(
            cfg.DISCRIMINATOR.CNN.embed_dim,
            cfg.DISCRIMINATOR.tgt_len,
            cfg.DISCRIMINATOR.CNN.num_rep,
            self.ntokens,
            1,
            cfg=cfg,
        )
    else:
        # Any other DISCRIMINATOR.type leaves the model without a discriminator.
        self.discriminator = None

    self.cfg = cfg
    self.temperature = 1
    self.vocab = vocab
    self.vec_len = vocab.vec_len