in torchrecipes/vision/image_generation/module/infogan.py
def info_loss_step(self) -> TrainOutput:
    # Sample class labels uniformly at random for this batch
    sampled_labels = np.random.randint(0, self.n_classes, self.cur_batch_size)
    # Ground-truth labels as a tensor for the categorical loss
    gt_labels = torch.tensor(
        sampled_labels, dtype=torch.long, device=self.device, requires_grad=False
    )
    # Sample noise, labels, and code as generator input
    z = torch.tensor(
        np.random.normal(0, 1, (self.cur_batch_size, self.latent_dim)),
        dtype=torch.float,
        device=self.device,
    )
    label_input = to_categorical(
        sampled_labels, num_columns=self.n_classes, device=self.device
    )
    code_input = torch.tensor(
        np.random.uniform(-1, 1, (self.cur_batch_size, self.code_dim)),
        dtype=torch.float,
        device=self.device,
    )
    # Generate a batch of images conditioned on the sampled labels and codes
    gen_imgs = self.generator(z, label_input, code_input)
    # Recover the predicted label and latent code from the generated images
    _, pred_label, pred_code = self.discriminator(gen_imgs)
    # Mutual-information loss: weighted sum of the categorical and continuous terms
    info_loss = LAMBDA_CAT * categorical_loss(
        pred_label, gt_labels
    ) + LAMBDA_CON * continuous_loss(pred_code, code_input)
    self.log("info_loss", info_loss, prog_bar=True)
    return info_loss
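
This step relies on a few module-level names that the excerpt does not show: to_categorical, categorical_loss, continuous_loss, LAMBDA_CAT, and LAMBDA_CON. Below is a minimal sketch of what they conventionally look like in an InfoGAN setup; the identifiers match the excerpt, but the concrete definitions (cross-entropy for the categorical head, MSE for the continuous head, and the 1.0 / 0.1 weights) are assumptions, not a copy of the module.

import numpy as np
import torch
import torch.nn as nn

# Assumed weights for the two information terms (common InfoGAN defaults).
LAMBDA_CAT = 1.0
LAMBDA_CON = 0.1

# Assumed loss choices: cross-entropy recovers the categorical code,
# mean-squared error recovers the continuous code.
categorical_loss = nn.CrossEntropyLoss()
continuous_loss = nn.MSELoss()


def to_categorical(y: np.ndarray, num_columns: int, device: torch.device) -> torch.Tensor:
    """One-hot encode integer labels into a float tensor on the given device."""
    y_cat = np.zeros((y.shape[0], num_columns), dtype=np.float32)
    y_cat[range(y.shape[0]), y] = 1.0
    return torch.tensor(y_cat, device=device)

The design intent behind the loss: the discriminator's auxiliary heads (pred_label, pred_code) act as the Q network that estimates the latent codes from generated images, so minimizing these reconstruction losses against the sampled labels and codes maximizes a variational lower bound on the mutual information between the codes and the generated samples.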