def info_loss_step()

in torchrecipes/vision/image_generation/module/infogan.py [0:0]


    def info_loss_step(self) -> TrainOutput:
        """Compute, log, and return the info loss for one freshly sampled batch.

        Draws ``self.cur_batch_size`` random class labels, standard-normal
        noise vectors and uniform codes in ``[-1, 1)`` with NumPy, runs them
        through ``self.generator``, and passes the generated images to
        ``self.discriminator`` to obtain predicted labels and codes.

        Returns:
            ``LAMDBA_CAT * categorical_loss(...) + LAMBDA_CON *
            continuous_loss(...)``; the same value is logged under the key
            ``"info_loss"`` and shown in the progress bar.
        """
        batch_size = self.cur_batch_size
        device = self.device

        # All random draws come from NumPy's global RNG. Keep this order
        # (labels, noise, codes) so seeded runs produce the same samples.
        class_ids = np.random.randint(0, self.n_classes, batch_size)
        noise_np = np.random.normal(0, 1, (batch_size, self.latent_dim))
        codes_np = np.random.uniform(-1, 1, (batch_size, self.code_dim))

        # Generator inputs: noise, categorical label encoding, continuous code.
        noise = torch.tensor(noise_np, dtype=torch.float, device=device)
        label_input = to_categorical(
            class_ids, num_columns=self.n_classes, device=device
        )
        code_input = torch.tensor(codes_np, dtype=torch.float, device=device)

        # Integer class targets for the categorical term.
        target_labels = torch.tensor(class_ids, dtype=torch.long, device=device)

        fake_images = self.generator(noise, label_input, code_input)
        # The discriminator's first output is not used by this loss.
        _, label_pred, code_pred = self.discriminator(fake_images)

        # NOTE(review): `LAMDBA_CAT` looks like a misspelling of `LAMBDA_CAT`;
        # it is kept as-is because the constant is defined outside this
        # method -- confirm against the module-level definition.
        cat_term = LAMDBA_CAT * categorical_loss(label_pred, target_labels)
        con_term = LAMBDA_CON * continuous_loss(code_pred, code_input)
        info_loss = cat_term + con_term

        self.log("info_loss", info_loss, prog_bar=True)
        return info_loss