in separate_vae/models/separate_clothing_encoder_models.py [0:0]
def forward(self, real_B_encoded, infer=False, sample=False, epoch=0, epoch_iter=0):
    '''Forward an input image through the (V)AE.

    Args:
        real_B_encoded (tensor): input image tensor; it is one-hot
            encoded by ``self.one_hot_tensor`` before use.
        infer (bool): if True, also return the decoded image.
        sample (bool), epoch (int), epoch_iter (int): deprecated, unused.

    Returns:
        [loss_MSE, loss_kl]: reconstruction loss and KL-divergence loss.
        real_B_encoded: the one-hot encoded, zero-centered input tensor.
        fake_B_encoded: decoded image if ``infer`` else None.
    '''
    # One-hot encode, then zero-center to [-0.5, 0.5].
    real_B_encoded = self.one_hot_tensor(real_B_encoded)
    real_B_encoded = real_B_encoded - 0.5

    # Encode each clothing-label channel separately with Separate_encoder,
    # and all non-clothing channels together with Together_encoder.
    # Collect per-head latents in order: clothing labels first, then the
    # clothing-unrelated code last (same layout as the original buffer).
    latents = []
    list_of_mus = []
    list_of_logvars = []
    if self.use_vae:
        for label_i in self.clothing_labels:
            z_encoded, mu, logvar = self.encode(
                self.Separate_encoder, real_B_encoded[:, label_i].unsqueeze(1))
            latents.append(z_encoded)
            list_of_mus.append(mu)
            list_of_logvars.append(logvar)
        z_encoded, mu, logvar = self.encode(
            self.Together_encoder, real_B_encoded[:, self.not_clothing_labels])
        latents.append(z_encoded)
        list_of_mus.append(mu)
        list_of_logvars.append(logvar)
    else:  # plain auto-encoder: encoders return the latent code directly
        for label_i in self.clothing_labels:
            latents.append(
                self.Separate_encoder(real_B_encoded[:, label_i].unsqueeze(1)))
        latents.append(self.Together_encoder(real_B_encoded[:, self.not_clothing_labels]))
    # torch.cat replaces the original pre-allocated
    # zeros(opt.batchSize, ...).cuda() buffer: it follows the input's
    # device (no hard-coded .cuda()) and handles a short final batch
    # instead of feeding zero-padded rows to the decoder.
    zs_encoded = torch.cat(latents, dim=1)

    # Decode, then zero-center to match the normalized input range.
    fake_B_encoded = self.Decoder(zs_encoded)
    fake_B_encoded = fake_B_encoded - 0.5

    # 2. KL loss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), summed over
    # every latent head. Algebraically identical to the original in-place
    # chain, written out-of-place for readability and autograd safety.
    if self.opt.lambda_kl > 0.0 and self.use_vae:
        loss_kl = 0
        for mu, logvar in zip(list_of_mus, list_of_logvars):
            loss_kl = loss_kl - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_kl = loss_kl * self.opt.lambda_kl
    else:
        loss_kl = 0

    # 3. Reconstruction loss |fake_B - real_B|, scaled by lambda_L1.
    if self.opt.lambda_L1 > 0.0:
        loss_MSE = self.criterionMSE(fake_B_encoded, real_B_encoded) * self.opt.lambda_L1
    else:
        loss_MSE = 0.0

    return [loss_MSE, loss_kl], real_B_encoded, fake_B_encoded if infer else None