def regenerate_img()

in one_shot_domain_adaptation.py [0:0]


def _render_uint8_images(generator, latents):
    """Synthesize images from ``latents`` and convert them to uint8 arrays.

    Parameters
    ----------
    generator : generator model
            Called as ``generator.forward(latents)``; assumed to return a
            batchsize x channels x height x width tensor nominally in [-1, 1]
            (TODO confirm against the generator implementation).
    latents : torch tensor
            Latent codes accepted by ``generator.forward``.

    Returns
    -------
    numpy array
            batchsize x height x width x channels, dtype uint8, in [0, 255].
    """
    with torch.no_grad():
        # ``Tensor.clamp`` is out-of-place: the result must be kept, otherwise
        # values outside [-1, 1] pass through unclamped.
        images = generator.forward(latents).clamp(-1, 1)

    images = images.cpu().float().numpy()
    # reshape from batchxcxhxw to batchxhxwxc and scale to [0, 255].
    images = (np.transpose(images, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    # Cast explicitly so imageio receives uint8 data instead of doing its own
    # (lossy, possibly min/max-rescaling) float-to-uint8 conversion.
    return np.rint(images).astype(np.uint8)


def regenerate_img(opt, img_batch, generator, exp_img_batch=None, img_names=None):
    """Regenerate images with StyleGAN by latent inversion plus weight fine-tuning.

    Finds latent codes for ``img_batch`` with ``find_latent_from_images``. It
    then keeps those latents fixed, sets per-layer ``requires_grad`` with
    ``reinitialize_requires_grad_diff_layers``, and fine-tunes the generator
    with ``finetune_weights_from_images``. The first regenerated image is
    written to ``<opt.output_folder>/<img_names[0]>_regenerated.png``.

    When ``opt.verbose`` is set, the function also writes:
    ``<name>_regenerated_before_opt.png`` (the result before fine-tuning),
    ``<name>_latent.npy`` (the latent codes) and the fine-tuned generator's
    ``state_dict``.

    Parameters
    ----------
    opt : options object
            Reads ``opt.output_folder`` and ``opt.verbose``; also passed through
            to the inversion / fine-tuning helpers.
    img_batch : numpy array
            batchsize x height x width x channels.
    generator : generator model
            Synthesis network. It is passed to the fine-tuning helper and then
            returned; presumably its weights are updated in place (confirm in
            ``finetune_weights_from_images``).
    exp_img_batch : numpy array, optional
            batchsize x height x width x channels. Currently unused; kept for
            interface compatibility.
    img_names : sequence of str, optional
            Only ``img_names[0]`` is used, as the file-name stem for every
            output. Defaults to ``"image"`` when omitted.

    Returns
    -------
    generator
            The (fine-tuned) generator that was passed in.
    latents_numpy : numpy array
            The latent codes found for ``img_batch``.
    """
    if img_names is None:
        # The stem was previously indexed unconditionally, so the documented
        # default crashed; fall back to a generic stem instead.
        img_names = ["image"]
    name = img_names[0]

    latents = find_latent_from_images(opt, img_batch, generator)

    if opt.verbose:
        # This pre-fine-tuning render is only needed for the debug snapshot.
        before_img = _render_uint8_images(generator, latents)
        imageio.imsave(
            "{}/{}_regenerated_before_opt.png".format(opt.output_folder, name),
            before_img[0],
        )

    # We then fix the latents and finetune weights from images.
    # Set the requires grad for each layer.
    reinitialize_requires_grad_diff_layers(generator, opt)
    finetune_weights_from_images(opt, img_batch, generator, latents)

    latents_numpy = latents.detach().cpu().numpy()
    if opt.verbose:
        # save the latent code
        np.save("{}/{}_latent.npy".format(opt.output_folder, name), latents_numpy)
        # save the fine-tuned model as well
        finetuned_path = "{}/karras2019stylegan-ffhq-1024x1024.for_g_synthesis_finetuned.pt".format(  # noqa: E501
            opt.output_folder
        )
        torch.save(generator.state_dict(), finetuned_path)

    # Render with the fine-tuned weights and save the first image.
    res_img = _render_uint8_images(generator, latents)
    imageio.imsave(
        "{}/{}_regenerated.png".format(opt.output_folder, name), res_img[0]
    )
    return generator, latents_numpy