def discriminator_step()

in torchrecipes/vision/image_generation/module/infogan.py [0:0]


    def discriminator_step(self) -> TrainOutput:
        """Compute, log, and return the discriminator's adversarial loss.

        The discriminator is run twice: first on ``self.real_imgs`` (scored
        against the ``self.valid`` targets), then on a detached copy of
        ``self.gen_imgs`` (scored against the ``self.fake`` targets). Only the
        first of the discriminator's three outputs feeds the loss. The two
        per-source losses are averaged, logged under ``"d_loss"`` on the
        progress bar, and returned.

        NOTE(review): ``self.valid`` / ``self.fake`` are presumably target
        tensors prepared earlier in the training step — confirm against the
        caller.
        """
        # Real batch first; order is kept because the discriminator may carry
        # per-call state (e.g. normalization statistics).
        real_validity, _, _ = self.discriminator(self.real_imgs)
        loss_on_real = adversarial_loss(real_validity, self.valid)

        # ``detach()`` keeps this loss from back-propagating into the generator.
        detached_fakes = self.gen_imgs.detach()
        fake_validity, _, _ = self.discriminator(detached_fakes)
        loss_on_fake = adversarial_loss(fake_validity, self.fake)

        # Equal weighting of the real and fake terms.
        combined = loss_on_real + loss_on_fake
        d_loss = combined / 2

        self.log("d_loss", d_loss, prog_bar=True)
        return d_loss