in model/utils/config_helper.py [0:0]
def discriminator(cfg):
    """Attach the default ``DISCRIMINATOR`` section to ``cfg`` and return it.

    A fresh ``CN()`` node is stored as ``cfg.DISCRIMINATOR`` and filled with
    the general discriminator/generator training defaults. Two further
    ``CN()`` nodes, ``BERT`` and ``CNN``, are then attached beneath it with
    their own defaults.

    Args:
        cfg: Config node to extend; it is modified in place.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    general_defaults = {
        "start_iter": 100,  # controls when training of the critic starts
        "dis_loss_freq": 50,  # how often the discriminator loss is used
        "gen_loss_freq": 10,
        "eval_loss_freq": 10,  # how often the discriminator loss is used during eval
        "freeze_discriminator": True,
        # While sampling, do not propagate gradients beyond the current token.
        "truncate_backprop": False,
        "sample_chunks_mem": 1,
        "beta_max": 100.0,  # TODO: temperature decay
        "adapt": "no",
        # "Null" for no discriminator, or cnn, or 'bert' for the BERT discriminator.
        "type": "Null",
        "dis_steps": 1,  # dis_step per gen_step (default 1 for bert and 5 for cnn)
        "tgt_len": 64,
        "mem_len": 64,
        # Multiplying factors for the mmd/gan loss component.
        "gen_loss_factor": 30,  # ... in the generator
        "dis_loss_factor": 1,  # ... in the discriminator
        "batch_chunk": 1,
        # Randomly sample `context_len` tokens from real data and use them as
        # context. NOTE(review): the original file carried a detached remark
        # "If 0 uses first token in real data" a few lines further down;
        # presumably it describes this option - confirm.
        "context_len": 5,
        "backprop_outside": True,
        "src_mem_len": 200,
        # Generator learning-rate / scheduler settings.
        "gen_scheduler": "constant",
        "gen_lr_min": 0.0,
        "gen_warmup_step": 0,
        "gen_decay_rate": 0.5,
        "gen_patience": 10,
        "gen_lr": 0.00025 / 4.0,
        # Discriminator learning-rate / scheduler settings.
        "dis_scheduler": "constant",
        "dis_lr_min": 0.0,
        "dis_warmup_step": 0,
        "dis_decay_rate": 0.5,
        "dis_patience": 10,
        "dis_lr": 0.00025 / 4.0,
    }

    # BERT params
    bert_defaults = {
        "learning_rate": 1e-5,  # decreased learning rate since we're fine tuning
        "weight_decay": 0.0,
        "adam_epsilon": 1e-8,
        "max_grad_norm": 1.0,
        "model_type": "bert_lm",  # or "bert_cls"
        # or 'standard', 'JS', 'KL', 'hinge', 'tv', 'rsgan', 'wgan-gp',
        # "mmd", 'ppo', 'ppo-gp'
        "loss_type": "rsgan",
        "model_path": "../BERT/checkpoint-1969000",
        "freeze_layers": [],  # total layers ['0', '1', '2', '3', '4']
        "random_weights": False,  # only implemented for bert_lm
    }

    # CNN params (Relgan)
    cnn_defaults = {
        "learning_rate": 1e-4,
        "embed_dim": 64,
        "hidden_dim": 64,
        "num_rep": 64,
        "init": "uniform",
        # or 'standard', 'JS', 'KL', 'hinge', 'tv', 'rsgan', 'wgan-gp',
        # "mmd", "ppo-gp"
        "loss_type": "rsgan",
    }

    # Attach each node first, then read it back through its parent so every
    # option is written to exactly the object the parent exposes.
    cfg.DISCRIMINATOR = CN()
    dis_node = cfg.DISCRIMINATOR
    for option, default in general_defaults.items():
        setattr(dis_node, option, default)

    dis_node.BERT = CN()
    bert_node = dis_node.BERT
    for option, default in bert_defaults.items():
        setattr(bert_node, option, default)

    dis_node.CNN = CN()
    cnn_node = dis_node.CNN
    for option, default in cnn_defaults.items():
        setattr(cnn_node, option, default)

    return cfg