def find_latent_from_images()

in one_shot_domain_adaptation.py


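The function below calls a project-specific VGGLoss that is defined elsewhere in the repository and not shown here. Purely as a point of reference, a minimal sketch of such a VGG feature-matching (perceptual) loss might look like the following; the choice of backbone, the layer cutoff, and the lack of ImageNet input renormalization are assumptions, not the repo's implementation:

import torch.nn as nn
from torchvision import models


class VGGFeatureLoss(nn.Module):
    """Hypothetical stand-in for VGGLoss: L1 distance between frozen VGG16
    feature maps of the generated and target images (an assumption)."""

    def __init__(self, layer_cutoff=16):
        super().__init__()
        # Frozen VGG16 feature extractor truncated at an intermediate layer.
        vgg = models.vgg16(pretrained=True).features[:layer_cutoff].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()

    def forward(self, generated, target):
        # Real perceptual losses usually renormalize inputs to ImageNet
        # statistics first; that step is omitted in this sketch.
        return self.criterion(self.vgg(generated), self.vgg(target))
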
import logging

import imageio
import numpy as np
import torch
import torch.nn.functional as F


def find_latent_from_images(opt, img_batch, generator):
    """Recover the latent code that reproduces a batch of images by iteratively
    backpropagating a reconstruction loss into the latent code."""
    loss_l1 = torch.nn.L1Loss()
    loss_l2 = torch.nn.MSELoss()
    loss_vgg = VGGLoss().cuda()  # project-specific VGG perceptual loss, defined elsewhere in the repo

    # Optimize an extended W+ latent code (18 style layers x 512 dims) directly with SGD.
    cur_latents = torch.zeros(1, 18, 512, device="cuda", requires_grad=True)
    optZ = torch.optim.SGD([cur_latents], lr=1)

    for iter in range(opt.num_iterations):
        generated = generator(cur_latents)
        # Downsample the generator's 1024x1024 output to 256x256, since the
        # VGG loss is computed on 256x256 images.
        generated = F.interpolate(generated, size=(256, 256), mode="bilinear", align_corners=False)
        generated = generated.clamp(-1, 1)

        if opt.verbose and iter % 100 == 0:
            # Periodically snapshot the current reconstruction for inspection.
            res_img = generated.detach().cpu().float().numpy()
            # Transpose from (batch, c, h, w) to (batch, h, w, c) and rescale from [-1, 1] to [0, 255].
            res_img = (np.transpose(res_img, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
            imageio.imwrite(
                "{}/regenerated_before_opt_{}.png".format(opt.output_folder, iter),
                res_img[0].astype(np.uint8),
            )

        # Reconstruction losses between the regenerated image and the target batch.
        recLoss = loss_vgg(generated, img_batch)
        recLoss_l1 = loss_l1(generated, img_batch)
        recLoss_l2 = loss_l2(generated, img_batch)
        # loss_type 1: VGG + 5 * L1, 2: L1 only, 3: L2 only, 4: VGG only.
        if opt.loss_type == 1:
            total_loss = recLoss + recLoss_l1 * 5
        elif opt.loss_type == 2:
            total_loss = recLoss_l1
        elif opt.loss_type == 3:
            total_loss = recLoss_l2
        elif opt.loss_type == 4:
            total_loss = recLoss
        else:
            raise ValueError("Unsupported loss_type: {}".format(opt.loss_type))

        optZ.zero_grad()
        total_loss.backward()  # the graph is rebuilt every iteration, so retain_graph is unnecessary
        optZ.step()
        if iter % 100 == 0:
            logging.info(
                "iter: {}. vgg_loss: {:05f}, l1_loss: {:05f}, total_loss: {:05f}".format(
                    iter,
                    recLoss.item(),
                    recLoss_l1.item(),
                    total_loss.item(),
                )
            )
    return cur_latents
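
A minimal usage sketch, assuming an argparse-style opt that carries the fields the function reads (num_iterations, loss_type, verbose, output_folder; defaults here are illustrative), a pretrained StyleGAN-like generator that maps a (1, 18, 512) W+ code to a 1024x1024 image, and a 256x256 target batch normalized to [-1, 1]. load_pretrained_generator and load_target_batch are hypothetical helpers, not part of the repo:

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--num_iterations", type=int, default=1000)
parser.add_argument("--loss_type", type=int, default=1)
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--output_folder", type=str, default="./outputs")
opt = parser.parse_args()

generator = load_pretrained_generator().cuda().eval()  # hypothetical loader
img_batch = load_target_batch().cuda()  # hypothetical: (1, 3, 256, 256) tensor in [-1, 1]

latents = find_latent_from_images(opt, img_batch, generator)
torch.save(latents.detach().cpu(), "{}/recovered_latents.pt".format(opt.output_folder))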