in models/losses.py [0:0]
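# Imports assumed from the top of models/losses.py (not shown in this excerpt):
#   import torch
#   import torch.nn.functional as F
# DiffAugment and r1_loss are helpers defined elsewhere in the project; their
# exact import paths are not shown here (a reference sketch of r1_loss is
# appended after the function).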
def forward(self, rgb_real, rgb_fake, depth_real, depth_fake, global_step, optimizer_idx):
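    """GAN loss for one optimizer step.

    optimizer_idx == 0 updates the generator, optimizer_idx == 1 the
    discriminator (matching the two branches below); global_step drives
    the lazy R1 regularization schedule.
    """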
    rgb_real.requires_grad_(True)  # R1 needs grads of the real logits w.r.t. the real RGB input
    if self.concat_depth:
        if depth_fake.shape[-1] != depth_real.shape[-1]:
            # downscale the real depth first so it carries no more detail than the fake depth
            depth_real = F.interpolate(depth_real, size=depth_fake.shape[-2:], mode='bilinear', align_corners=False)
        # then resize both depth maps back up to the RGB resolution so they can be concatenated
        depth_real = F.interpolate(depth_real, size=rgb_real.shape[-2:], mode='bilinear', align_corners=False)
        depth_fake = F.interpolate(depth_fake, size=rgb_real.shape[-2:], mode='bilinear', align_corners=False)
        disc_in_real = torch.cat([rgb_real, depth_real], dim=1)
        disc_in_fake = torch.cat([rgb_fake, depth_fake], dim=1)
    else:
        disc_in_real = rgb_real
        disc_in_fake = rgb_fake
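    # apply the same differentiable augmentation to real and fake inputs so the
    # discriminator sees both through an identical augmentation distribution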
    if self.aug_policy:
        disc_in_real = DiffAugment(disc_in_real, normalize=True, policy=self.aug_policy)
        disc_in_fake = DiffAugment(disc_in_fake, normalize=True, policy=self.aug_policy)
    if optimizer_idx == 0:  # optimize the generator
        logits_fake, _ = self.discriminator(disc_in_fake)
        g_loss = self.disc_loss(fake_pred=logits_fake, real_pred=None, mode='g')
        log = {"loss_train/g_loss": g_loss.detach()}
        return g_loss, log
    if optimizer_idx == 1:  # optimize the discriminator
        logits_real, recon_real = self.discriminator(disc_in_real)
        logits_fake, _ = self.discriminator(disc_in_fake.detach())  # detach so no grads flow into the generator
        disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
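        # auxiliary reconstruction term: the discriminator also decodes its input
        # (second output above); penalizing the reconstruction error keeps its
        # features informative, in the spirit of a self-supervised discriminator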
        disc_recon_loss = F.mse_loss(recon_real, disc_in_real) * self.recon_weight
        # lazy regularization: only compute the (expensive) gradient penalty every d_reg_every steps
        if (global_step % self.d_reg_every == 0) and self.r1_weight > 0:
            grad_penalty = r1_loss(logits_real, rgb_real)
            # the 0 * logits_real term touches the logits so DDP's allgather is triggered
            # https://github.com/rosinality/stylegan2-pytorch/issues/76
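            # multiplying by d_reg_every keeps the effective R1 strength the same
            # as if the penalty were applied on every step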
            grad_penalty = self.r1_weight / 2 * grad_penalty * self.d_reg_every + (0 * logits_real.sum())
        else:
            grad_penalty = torch.tensor(0.0, device=disc_loss.device)  # keep device consistent for the sum and the log
        d_loss = disc_loss + disc_recon_loss + grad_penalty
        log = {
            "loss_train/disc_loss": disc_loss.detach(),
            "loss_train/disc_recon_loss": disc_recon_loss.detach(),
            "loss_train/r1_loss": grad_penalty.detach(),
            "loss_train/d_loss": d_loss.detach(),
            "loss_train/logits_real": logits_real.mean().detach(),
            "loss_train/logits_fake": logits_fake.mean().detach(),
        }
        return d_loss, log
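
# Reference sketch (not part of the original file): a minimal r1_loss consistent
# with the call above, following the standard R1 penalty (Mescheder et al., 2018)
# as implemented in rosinality/stylegan2-pytorch, which the comment above links
# to. The project's actual helper may differ.
def r1_loss(real_pred, real_img):
    # gradient of the summed real logits w.r.t. the real images; create_graph=True
    # so the penalty itself is differentiable for the discriminator update
    grad_real, = torch.autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    # per-sample squared gradient norm, averaged over the batch
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()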