in models/losses/gan_loss.py [0:0]
def loss(self, input, target_is_real, for_discriminator=True):
    """Compute the adversarial loss for one discriminator output tensor.

    The formulation is chosen by ``self.gan_mode``:

    * ``"original"``: binary cross-entropy with logits between ``input``
      and ``self.get_target_tensor(input, target_is_real)``.
    * ``"ls"``: mean squared error against that same target tensor.
    * ``"hinge"``: for the discriminator,
      ``-mean(min(s - 1, zero))`` where ``s`` is ``input`` for real
      targets and ``-input`` for fake ones, and ``zero`` is
      ``self.get_zero_tensor(input)``; for the generator, ``-mean(input)``.
    * any other value: treated as WGAN, ``-mean(input)`` for real targets
      and ``mean(input)`` for fake ones.

    Args:
        input: Discriminator output tensor. Presumably raw, unnormalised
            scores (the "original" mode uses the *with_logits* loss) --
            confirm against the caller.
        target_is_real: Truthy when the loss should push ``input`` towards
            the "real" side.
        for_discriminator: Only consulted in hinge mode; selects the
            discriminator (True) or the generator (False) formulation.

    Returns:
        A scalar tensor holding the mean-reduced loss.

    Raises:
        AssertionError: In hinge mode, when the generator loss is requested
            with a falsy ``target_is_real``.
    """
    mode = self.gan_mode

    if mode == "original":  # cross entropy loss
        target_tensor = self.get_target_tensor(input, target_is_real)
        return F.binary_cross_entropy_with_logits(input, target_tensor)

    if mode == "ls":
        target_tensor = self.get_target_tensor(input, target_is_real)
        return F.mse_loss(input, target_tensor)

    if mode == "hinge":
        if not for_discriminator:
            # Raised explicitly rather than via an ``assert`` statement so
            # the check is not stripped when Python runs with -O; the
            # exception type and message stay the same for callers.
            if not target_is_real:
                raise AssertionError(
                    "The generator's hinge loss must be aiming for real"
                )
            return -torch.mean(input)

        # Real and fake targets differ only in the sign applied to the
        # scores before the margin of 1 is subtracted.
        signed = input if target_is_real else -input
        minval = torch.min(signed - 1, self.get_zero_tensor(input))
        return -torch.mean(minval)

    # wgan
    return -input.mean() if target_is_real else input.mean()