in src/train.py [0:0]
# Module-level imports this method depends on:
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_value_

def train_batch(self, x, x_aug, dset_num):
    x, x_aug = x.float(), x_aug.float()

    # Step 1: optimize D - train the domain classifier to be *right*
    # about which dataset the latent code came from.
    z = self.encoder(x)
    z_logits = self.discriminator(z)
    domain_target = torch.tensor([dset_num] * x.size(0)).long().cuda()
    discriminator_right = F.cross_entropy(z_logits, domain_target).mean()
    loss = discriminator_right * self.args.d_lambda

    self.d_optimizer.zero_grad()
    loss.backward()
    if self.args.grad_clip is not None:
        clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)
    self.d_optimizer.step()

    # Step 2: optimize G (encoder + decoder) - reconstruct well while
    # making the discriminator *wrong* via a negated cross-entropy,
    # which pushes the encoder toward domain-invariant latents.
    z = self.encoder(x_aug)
    y = self.decoder(x, z)
    z_logits = self.discriminator(z)
    discriminator_wrong = -F.cross_entropy(z_logits, domain_target).mean()

    # Sanity check: dump the logits when the adversarial loss diverges.
    if not (-100 < discriminator_right.data.item() < 100):
        self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
        self.logger.debug(f'dset_num: {dset_num}')

    # Per-sample reconstruction loss (see the cross_entropy_loss sketch below).
    recon_loss = cross_entropy_loss(y, x)
    self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

    loss = recon_loss.mean() + self.args.d_lambda * discriminator_wrong

    self.model_optimizer.zero_grad()
    loss.backward()
    if self.args.grad_clip is not None:
        clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
        clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
    self.model_optimizer.step()

    # Running loss meter for reporting; return the scalar batch loss.
    self.loss_total.add(loss.data.item())
    return loss.data.item()
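# The reconstruction term above calls a cross_entropy_loss helper that this
# excerpt does not show. A minimal sketch, assuming the decoder emits
# per-timestep class logits of shape (batch, classes, time) and x holds
# integer class targets of shape (batch, time); returning the unreduced
# per-sample loss is what lets train_batch call .mean() on it afterwards.
def cross_entropy_loss(logits, target):
    batch, classes, time = logits.size()
    # Flatten the (batch, time) positions so each one becomes an
    # independent classification example for F.cross_entropy.
    flat_logits = logits.transpose(1, 2).reshape(batch * time, classes)
    flat_target = target.long().reshape(batch * time)
    ce = F.cross_entropy(flat_logits, flat_target, reduction='none')
    return ce.reshape(batch, time)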
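# The two-optimizer split is implied but never shown: d_optimizer must own
# only the discriminator's parameters and model_optimizer only the encoder's
# and decoder's, otherwise each backward pass would also update the other
# player. A hypothetical wiring consistent with the updates above (Adam and
# self.args.lr are assumptions), e.g. inside the trainer's __init__:
import itertools

self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                    lr=self.args.lr)
self.model_optimizer = torch.optim.Adam(
    itertools.chain(self.encoder.parameters(), self.decoder.parameters()),
    lr=self.args.lr)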
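# A hypothetical caller, for context: train_batch consumes one batch from a
# single domain at a time, with dset_num identifying that domain. The names
# trainer, loaders, and num_epochs are illustrative, not from the excerpt.
for epoch in range(num_epochs):
    for dset_num, loader in enumerate(loaders):
        for x, x_aug in loader:
            batch_loss = trainer.train_batch(x, x_aug, dset_num)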