def discriminator()

in model/utils/config_helper.py [0:0]


def discriminator(cfg):
    # Discriminator-related options (used only when a discriminator is enabled,
    # i.e. when DISCRIMINATOR.type is not "Null")
    cfg.DISCRIMINATOR = CN()
    cfg.DISCRIMINATOR.start_iter = 100  # Iteration at which discriminator (critic) training starts
    cfg.DISCRIMINATOR.dis_loss_freq = 50  # How often (in steps) to use the loss from the discriminator
    cfg.DISCRIMINATOR.gen_loss_freq = 10  # How often (in steps) to apply the adversarial loss to the generator

    cfg.DISCRIMINATOR.eval_loss_freq = 10  # How often to use the loss from the discriminator during evaluation
    cfg.DISCRIMINATOR.freeze_discriminator = True
    cfg.DISCRIMINATOR.truncate_backprop = False  # While sampling, do not propagate gradients beyond the current token
    cfg.DISCRIMINATOR.sample_chunks_mem = 1
    cfg.DISCRIMINATOR.beta_max = 100.  # TODO: temperature decay
    cfg.DISCRIMINATOR.adapt = 'no'
    cfg.DISCRIMINATOR.type = "Null"  # "Null" for no discriminator, "cnn" for the CNN (RelGAN) discriminator, or "bert" for the BERT discriminator
    cfg.DISCRIMINATOR.dis_steps = 1  # Discriminator steps per generator step (default 1 for bert, 5 for cnn)
    cfg.DISCRIMINATOR.tgt_len = 64
    cfg.DISCRIMINATOR.mem_len = 64
    cfg.DISCRIMINATOR.gen_loss_factor = 30  # Multiplying factor for the MMD/GAN loss component in the generator
    cfg.DISCRIMINATOR.dis_loss_factor = 1  # Multiplying factor for the MMD/GAN loss component in the discriminator
    cfg.DISCRIMINATOR.batch_chunk = 1
    # Randomly sample context_len tokens from real data and use them as context;
    # if 0, use the first token of the real data
    cfg.DISCRIMINATOR.context_len = 5
    cfg.DISCRIMINATOR.backprop_outside = True
    cfg.DISCRIMINATOR.src_mem_len = 200

    cfg.DISCRIMINATOR.gen_scheduler = "constant"
    cfg.DISCRIMINATOR.gen_lr_min = 0.0
    cfg.DISCRIMINATOR.gen_warmup_step = 0
    cfg.DISCRIMINATOR.gen_decay_rate = 0.5
    cfg.DISCRIMINATOR.gen_patience = 10
    cfg.DISCRIMINATOR.gen_lr = 0.00025 / 4.0

    cfg.DISCRIMINATOR.dis_scheduler = "constant"
    cfg.DISCRIMINATOR.dis_lr_min = 0.0
    cfg.DISCRIMINATOR.dis_warmup_step = 0
    cfg.DISCRIMINATOR.dis_decay_rate = 0.5
    cfg.DISCRIMINATOR.dis_patience = 10
    cfg.DISCRIMINATOR.dis_lr = 0.00025 / 4.0

    # BERT discriminator params
    cfg.DISCRIMINATOR.BERT = CN()
    cfg.DISCRIMINATOR.BERT.learning_rate = 1e-5  # Lower learning rate since we are fine-tuning a pretrained model
    cfg.DISCRIMINATOR.BERT.weight_decay = 0.0
    cfg.DISCRIMINATOR.BERT.adam_epsilon = 1e-8
    cfg.DISCRIMINATOR.BERT.max_grad_norm = 1.0
    cfg.DISCRIMINATOR.BERT.model_type = "bert_lm"  # or "bert_cls"
    cfg.DISCRIMINATOR.BERT.loss_type = "rsgan"  # One of 'standard', 'JS', 'KL', 'hinge', 'tv', 'rsgan', 'wgan-gp', 'mmd', 'ppo', 'ppo-gp'
    cfg.DISCRIMINATOR.BERT.model_path = "../BERT/checkpoint-1969000"
    cfg.DISCRIMINATOR.BERT.freeze_layers = []  # Layers to freeze; all layers: ['0', '1', '2', '3', '4']
    cfg.DISCRIMINATOR.BERT.random_weights = False  # only implemented for bert_lm

    # CNN discriminator params (RelGAN)
    cfg.DISCRIMINATOR.CNN = CN()
    cfg.DISCRIMINATOR.CNN.learning_rate = 1e-4
    cfg.DISCRIMINATOR.CNN.embed_dim = 64
    cfg.DISCRIMINATOR.CNN.hidden_dim = 64
    cfg.DISCRIMINATOR.CNN.num_rep = 64
    cfg.DISCRIMINATOR.CNN.init = "uniform"
    cfg.DISCRIMINATOR.CNN.loss_type = "rsgan"  # One of 'standard', 'JS', 'KL', 'hinge', 'tv', 'rsgan', 'wgan-gp', 'mmd', 'ppo-gp'
    return cfg
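
A minimal usage sketch follows. It assumes CN in this file is yacs.config.CfgNode (the import is not shown in this excerpt), and the training-loop fragment is purely illustrative of how the start_iter / *_loss_freq knobs could gate updates; it is not the repository's actual trainer.

from yacs.config import CfgNode as CN  # assumption: CN here is yacs' CfgNode

cfg = CN()
cfg = discriminator(cfg)  # attach the DISCRIMINATOR defaults defined above

# Override defaults the usual yacs way, e.g. switch to the BERT discriminator
cfg.merge_from_list(["DISCRIMINATOR.type", "bert",
                     "DISCRIMINATOR.dis_steps", 1])
print(cfg.DISCRIMINATOR.BERT.loss_type)  # -> rsgan

# Illustrative gating of adversarial updates by the frequency knobs (sketch only)
for step in range(1, 1001):
    if step >= cfg.DISCRIMINATOR.start_iter:
        if step % cfg.DISCRIMINATOR.dis_loss_freq == 0:
            pass  # update the discriminator here
        if step % cfg.DISCRIMINATOR.gen_loss_freq == 0:
            pass  # apply the adversarial loss to the generator here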