# forward() — excerpt from separate_vae/models/separate_clothing_encoder_models.py

    def forward(self, real_B_encoded, infer=False, sample=False, epoch=0, epoch_iter=0):
        '''Forward an input label-map image through the (separate) VAE.

        Args:
            real_B_encoded (tensor): input image tensor; it is one-hot
                binarized then shifted to [-0.5, 0.5] before encoding.
            infer (bool): whether to also return the decoded image.
            sample (bool), epoch (int), epoch_iter (int): deprecated, unused.

        Returns:
            [loss_MSE, loss_kl]: reconstruction loss and KL-divergence loss.
            real_B_encoded: binarized-then-normalized input image tensor.
            fake_B_encoded: decoded image if `infer`, else None.
        '''
        # Binarize to one-hot channels, then shift {0, 1} -> {-0.5, 0.5}.
        real_B_encoded = self.one_hot_tensor(real_B_encoded)
        real_B_encoded = real_B_encoded - 0.5

        # Latent buffer: one nz-wide slot per clothing label plus one joint
        # slot for the clothing-unrelated labels. Use the actual batch size
        # (the last batch of an epoch may be smaller than opt.batchSize) and
        # the input's device (avoids a hard-coded .cuda(), so CPU runs work).
        batch_size = real_B_encoded.size(0)
        n_slots = len(self.clothing_labels) + 1
        zs_encoded = torch.zeros(batch_size, self.opt.nz * n_slots,
                                 device=real_B_encoded.device)
        list_of_mus = []
        list_of_logvars = []
        if self.use_vae:
            # Clothing-related labels: each channel is encoded separately.
            for count_i, label_i in enumerate(self.clothing_labels):
                z_encoded, mu, logvar = self.encode(
                    self.Separate_encoder, real_B_encoded[:, label_i].unsqueeze(1))
                zs_encoded[:, count_i * self.opt.nz: (count_i + 1) * self.opt.nz] = z_encoded
                list_of_mus.append(mu)
                list_of_logvars.append(logvar)
            # Clothing-unrelated labels: encoded jointly in one pass.
            z_encoded, mu, logvar = self.encode(
                self.Together_encoder, real_B_encoded[:, self.not_clothing_labels])
            zs_encoded[:, -self.opt.nz:] = z_encoded
            list_of_mus.append(mu)
            list_of_logvars.append(logvar)
        else:  # plain auto-encoder: encoders return z directly, no (mu, logvar)
            for count_i, label_i in enumerate(self.clothing_labels):
                zs_encoded[:, count_i * self.opt.nz: (count_i + 1) * self.opt.nz] = \
                    self.Separate_encoder(real_B_encoded[:, label_i].unsqueeze(1))
            zs_encoded[:, -self.opt.nz:] = self.Together_encoder(
                real_B_encoded[:, self.not_clothing_labels])

        # Decode, then shift to the same [-0.5, 0.5] range as the target.
        fake_B_encoded = self.Decoder(zs_encoded)
        fake_B_encoded = fake_B_encoded - 0.5

        # KL divergence: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
        # accumulated over every latent slot.
        if self.opt.lambda_kl > 0.0 and self.use_vae:
            loss_kl = 0
            for mu, logvar in zip(list_of_mus, list_of_logvars):
                kl_element = 1 + logvar - mu.pow(2) - logvar.exp()
                loss_kl += torch.sum(kl_element) * (-0.5)
            loss_kl = loss_kl * self.opt.lambda_kl
        else:
            loss_kl = 0

        # Reconstruction loss |fake_B - real_B| (MSE, weighted by lambda_L1).
        if self.opt.lambda_L1 > 0.0:
            loss_MSE = self.criterionMSE(fake_B_encoded, real_B_encoded) * self.opt.lambda_L1
        else:
            loss_MSE = 0.0
        return [loss_MSE, loss_kl], real_B_encoded, fake_B_encoded if infer else None