in generation/models/pix2pixHD_model.py [0:0]
def forward(self, label, inst, image, feat, infer=False):
    """Run one training pass and return the filtered loss terms.

    The raw inputs go through ``self.encode_input``. The generator then
    synthesizes an image from the encoded label map (concatenated along
    dim 1 with a feature map when ``self.use_features`` is set), and the
    discriminator, GAN feature-matching, VGG/style and reconstruction
    losses are computed against the encoded real image.

    Args:
        label: Label input, forwarded to ``encode_input``. When
            ``opt.label_feat`` is set it is also moved to CUDA and handed
            to the encoder ``self.netE``.
        inst: Instance input, forwarded to ``encode_input``.
        image: Real image input, forwarded to ``encode_input``.
        feat: Feature input, forwarded to ``encode_input``; its encoded
            form is used directly when ``opt.load_features`` is set.
        infer: When true, the generated image is returned alongside the
            losses; otherwise ``None`` is returned in its place.

    Returns:
        A two-element list. The first element is the result of
        ``self.loss_filter`` applied to (G_GAN, G_GAN_Feat, G_VGG,
        G_style_VGG, G_recon, D_real, D_fake); the second is the generated
        image or ``None``. Tensor losses get exactly one leading dimension
        of size 1; disabled terms remain the int placeholder ``0``.
    """
    # NOTE(review): the nesting below was reconstructed from statement order
    # because the original indentation was lost -- confirm against upstream.

    def as_batched(loss):
        # Disabled terms are the plain int 0, which has no ``unsqueeze``;
        # only real tensors receive the extra leading dimension.
        # (Presumably needed so multi-GPU gathering can concatenate the
        # per-device losses -- TODO confirm.)
        return loss.unsqueeze(0) if torch.is_tensor(loss) else loss

    # Encode inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(
        label, inst, image, feat)

    # Fake generation
    if self.use_features:
        if not self.opt.load_features:
            # Features are encoded from the real image, conditioned on either
            # the label map or the instance map.
            if self.opt.label_feat:
                condition = label.data.cuda()
            else:
                condition = inst_map
            if self.opt.faster:
                feat_map = self.netE.forward_fast(real_image, condition)
            else:
                feat_map = self.netE.forward(real_image, condition)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label
    fake_image = self.netG.forward(input_concat)

    # Fake detection and loss
    pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
    loss_D_fake = as_batched(self.criterionGAN(pred_fake_pool, False))

    # Real detection and loss
    pred_real = self.discriminate(input_label, real_image)
    loss_D_real = as_batched(self.criterionGAN(pred_real, True))

    # GAN loss (fake passability loss)
    pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
    loss_G_GAN = as_batched(self.criterionGAN(pred_fake, True))

    # GAN feature matching loss
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for i in range(self.opt.num_D):
            # The last entry of each discriminator's output list is skipped.
            for j in range(len(pred_fake[i]) - 1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
        # Stays the int 0 if the loops above never ran; as_batched copes.
        loss_G_GAN_Feat = as_batched(loss_G_GAN_Feat)

    # VGG feature matching loss (and optional style loss)
    loss_G_VGG = 0
    loss_G_style_VGG = 0
    if not self.opt.no_style_loss:
        loss_G_VGG, loss_G_style_VGG = self.criterionVGG(fake_image, real_image)
        loss_G_VGG *= self.opt.lambda_feat
        loss_G_style_VGG *= self.opt.lambda_style
        loss_G_style_VGG = as_batched(loss_G_style_VGG)
    elif not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
    # Unsqueezed exactly once on every path (the original applied it twice on
    # the VGG-only path and could hit the int placeholder when disabled).
    loss_G_VGG = as_batched(loss_G_VGG)

    # L1 reconstruction loss
    loss_G_recon = 0
    if not self.opt.no_recon_loss:
        loss_G_recon = self.criterionRecon(fake_image, real_image) * self.opt.lambda_recon  # loss(input, target)
        loss_G_recon = as_batched(loss_G_recon)

    # Only return the fake image if necessary to save bandwidth
    losses = self.loss_filter(
        loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_style_VGG,
        loss_G_recon, loss_D_real, loss_D_fake)
    return [losses, fake_image if infer else None]