lib/utils.py [45:60]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def dcgan_loss_gen(x_fake, netD, device):
    """Compute the non-saturating DCGAN generator loss.

    Args:
        x_fake: Generated samples fed to the discriminator.
        netD: Discriminator callable; its output is used as a raw logit.
        device: Unused here; presumably kept so all loss helpers share a
            signature (confirm against callers).

    Returns:
        A ``(loss, logits)`` pair: the scalar mean loss and the raw
        discriminator output for ``x_fake``.
    """
    logits = netD(x_fake)
    # softplus(-z) == -log(sigmoid(z)), so this is the mean of
    # -log D(G(z)) computed stably from the logit.
    loss = F.softplus(-logits).mean()
    return loss, logits


def wgan_loss_gen(x_fake, netD, device):
    """Compute the WGAN generator loss ``-mean(netD(x_fake))``.

    Args:
        x_fake: Generated samples scored by the critic.
        netD: Critic callable returning an object with a ``.mean()`` method.
        device: Unused here; presumably kept for a uniform signature.

    Returns:
        A ``(loss, critic_scores)`` pair.
    """
    critic_scores = netD(x_fake)
    return -critic_scores.mean(), critic_scores


def wgan_loss_dis(x_real, x_fake, netD, device):
    """Compute the WGAN critic loss ``mean(D(fake)) - mean(D(real))``.

    The critic is evaluated on the real batch first, then on the fake
    batch.

    Args:
        x_real: Real samples.
        x_fake: Generated samples.
        netD: Critic callable.
        device: Unused here; presumably kept for a uniform signature.

    Returns:
        A ``(loss, real_scores, fake_scores)`` triple.
    """
    real_scores = netD(x_real)
    fake_scores = netD(x_fake)
    loss = fake_scores.mean() - real_scores.mean()
    return loss, real_scores, fake_scores
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



train_mnist.py [200:219]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def dcgan_loss_gen(x_fake, netD, device):
    """Return the non-saturating DCGAN generator loss and raw D output.

    The discriminator output is treated as an unnormalized logit:
    ``softplus(-z)`` equals ``-log(sigmoid(z))``.

    Args:
        x_fake: Generated samples.
        netD: Discriminator callable.
        device: Not used in this function.

    Returns:
        ``(loss, d_out)`` where ``loss`` is the batch mean.
    """
    d_out = netD(x_fake)
    per_sample = F.softplus(-d_out)
    return per_sample.mean(), d_out


def wgan_loss_gen(x_fake, netD, device):
    """Return the WGAN generator loss (negated mean critic score).

    Args:
        x_fake: Generated samples.
        netD: Critic callable.
        device: Not used in this function.

    Returns:
        ``(loss, fake_scores)``.
    """
    fake_scores = netD(x_fake)
    mean_score = fake_scores.mean()
    return -mean_score, fake_scores


def wgan_loss_dis(x_real, x_fake, netD, device):
    """Return the WGAN critic loss ``mean(D(fake)) - mean(D(real))``.

    The critic is called on ``x_real`` before ``x_fake``.

    Args:
        x_real: Real samples.
        x_fake: Generated samples.
        netD: Critic callable.
        device: Not used in this function.

    Returns:
        ``(loss, real_scores, fake_scores)``.
    """
    real_scores = netD(x_real)
    fake_scores = netD(x_fake)
    real_mean = real_scores.mean()
    fake_mean = fake_scores.mean()
    # NOTE: no gradient-penalty term is added here. An earlier version
    # optionally added gp_lambda * netD.get_penalty(x_real.detach(),
    # x_fake.detach()) when a grad_penalty flag was set; that code was
    # disabled and neither name is defined in this function.
    return fake_mean - real_mean, real_scores, fake_scores
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



