def finetune_weights_from_images()

in one_shot_domain_adaptation.py [0:0]


def _save_preview(opt, generated, step):
    """Save the first image of ``generated`` as ``regenerated_<step>.png``.

    ``generated`` is treated as a (batch, c, h, w) tensor whose values map onto
    [0, 255] via ``(x + 1) / 2 * 255`` (i.e. presumably in [-1, 1]). It is
    clamped to [-1, 1] first and written as uint8, so the PNG holds exactly
    the values computed here instead of an implicit float->uint8 conversion
    by imageio.
    """
    preview = generated.detach().clamp(-1, 1).cpu().float().numpy()
    # (batch, c, h, w) -> (batch, h, w, c), then map [-1, 1] onto [0, 255].
    preview = (np.transpose(preview, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    imageio.imsave(
        "{}/regenerated_{}.png".format(opt.output_folder, step),
        preview[0].round().astype(np.uint8),
    )


def finetune_weights_from_images(opt, img_batch, generator, latents):
    """Fine-tune ``generator``'s weights so that it reproduces ``img_batch``.

    The latent codes are held fixed (detached). Every parameter of
    ``generator`` is updated with plain SGD (lr=1) for ``opt.num_iterations``
    steps. Each step minimizes a reconstruction loss between the generator
    output, resized to 256x256, and ``img_batch``.

    Args:
        opt: Options object. Reads ``num_iterations``, ``loss_type``,
            ``verbose`` and ``output_folder``. ``loss_type`` selects the
            objective:
            1 = VGG + 5*L1, 2 = 5*L1, 3 = MSE, 4 = VGG.
            Any other value logs a warning and falls back to 1.
        img_batch: Target images, compared directly against the 256x256
            generator output. Presumably (batch, c, 256, 256) on the same
            (CUDA) device and value range as the generator output; TODO
            confirm with callers.
        generator: Module mapping ``latents`` to images. Its parameters are
            updated in place by the optimizer.
        latents: Latent codes fed to ``generator``. They are detached, so they
            are not optimized.

    Returns:
        The same ``generator`` object, after fine-tuning.
    """
    loss_type = opt.loss_type
    if loss_type not in (1, 2, 3, 4):
        # This function historically optimized VGG + 5*L1 regardless of
        # loss_type; keep that as the fallback so existing configs still run.
        logging.warning(
            "Unknown loss_type %r; falling back to VGG + L1 (loss_type 1).", loss_type
        )
        loss_type = 1

    latents = latents.detach()
    vgg_loss = VGGLoss().cuda()
    loss_l1 = torch.nn.L1Loss()
    loss_l2 = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(generator.parameters(), lr=1)

    for step in range(opt.num_iterations):
        # Call the module rather than .forward() so registered hooks run.
        generated = generator(latents)
        # The losses are computed on 256x256 images, so resize the output
        # (presumably 1024x1024 -> 256x256). Same arguments/default as the
        # deprecated F.upsample alias.
        generated = F.interpolate(
            generated, size=(256, 256), mode="bilinear", align_corners=False
        )
        # Losses use the unclamped output: Tensor.clamp is not in-place, and
        # clamping this tensor would zero the gradient of out-of-range pixels.
        # Only the saved preview is clamped.

        if opt.verbose and step % 100 == 0:
            _save_preview(opt, generated, step)

        rec_vgg = vgg_loss(generated, img_batch)
        rec_l1 = loss_l1(generated, img_batch) * 5  # L1 term carries weight 5.
        rec_l2 = loss_l2(generated, img_batch)

        if loss_type == 1:
            # rec_l1 already includes its x5 weight; add it once (VGG + 5*L1).
            total_loss = rec_vgg + rec_l1
        elif loss_type == 2:
            total_loss = rec_l1
        elif loss_type == 3:
            total_loss = rec_l2
        else:  # loss_type == 4
            total_loss = rec_vgg

        if step % 100 == 0:
            logging.info(
                "recLoss: %s recLoss_l1: %05f", rec_vgg.item(), rec_l1.item()
            )

        optimizer.zero_grad()
        # The graph is rebuilt every step, so there is no need to retain it.
        total_loss.backward()
        optimizer.step()

        # .item() forces a device sync, so only read the loss when logging.
        if step % 100 == 0:
            logging.info("iter: %d loss: %05f", step, total_loss.item())
    return generator