in torchrecipes/vision/image_generation/module/infogan.py [0:0]
def discriminator_step(self) -> TrainOutput:
    """Compute, log, and return the discriminator loss for the current batch.

    The discriminator is run once on ``self.real_imgs`` (scored against
    ``self.valid``) and once on a detached copy of ``self.gen_imgs``
    (scored against ``self.fake``). The returned loss is the mean of
    those two adversarial losses, and is also logged as ``"d_loss"``.

    Returns:
        The averaged real/fake adversarial loss.
    """
    # Real branch: only the first of the three discriminator outputs is
    # used here; the other two are ignored.
    real_outputs = self.discriminator(self.real_imgs)
    real_validity, _, _ = real_outputs
    loss_on_real = adversarial_loss(real_validity, self.valid)

    # Fake branch: the generated images are detached before scoring —
    # presumably so this loss does not backpropagate into the generator
    # (assumes gen_imgs is a torch tensor; confirm against the caller).
    fake_outputs = self.discriminator(self.gen_imgs.detach())
    fake_validity, _, _ = fake_outputs
    loss_on_fake = adversarial_loss(fake_validity, self.fake)

    # Average the two halves so real and fake contribute equally.
    combined = loss_on_real + loss_on_fake
    total_loss = combined / 2

    self.log("d_loss", total_loss, prog_bar=True)
    return total_loss