in one_shot_domain_adaptation.py [0:0]
def find_latent_from_images(opt, img_batch, generator):
    """Find the latent code for ``img_batch`` using iterative backpropagation.

    A single zero-initialised latent of shape (1, 18, 512) is optimised with
    SGD (lr=1) so that the generator output, resized to 256x256, matches
    ``img_batch`` under the loss selected by ``opt.loss_type``:

    * 1 -- VGG loss + 5 * L1 loss
    * 2 -- L1 loss
    * 3 -- L2 (MSE) loss
    * 4 -- VGG loss

    Every 100 iterations the individual losses are logged and, when
    ``opt.verbose`` is truthy, the current reconstruction is written to
    ``opt.output_folder`` as ``regenerated_before_opt_<iteration>.png``.

    Args:
        opt: Options object. Reads ``num_iterations``, ``loss_type``,
            ``verbose`` and ``output_folder``.
        img_batch: Target image tensor compared against the 256x256
            generator output. NOTE(review): presumably a CUDA tensor of shape
            (1, C, 256, 256) with values in [-1, 1] -- confirm with caller.
        generator: Module whose ``forward`` maps the latent tensor to an
            image tensor.

    Returns:
        The optimised latent tensor of shape (1, 18, 512). It lives on the
        CUDA device and still has ``requires_grad`` set.

    Raises:
        ValueError: If ``opt.loss_type`` is not one of 1, 2, 3 or 4.
    """
    l1_criterion = torch.nn.L1Loss()
    l2_criterion = torch.nn.MSELoss()
    vgg_criterion = VGGLoss().cuda()

    # Leaf tensor that the optimiser updates in place.
    cur_latents = torch.zeros(1, 18, 512, device="cuda", requires_grad=True)
    optZ = torch.optim.SGD([cur_latents], lr=1)

    for step in range(opt.num_iterations):
        generated = generator.forward(cur_latents)
        # We need to downsample the output image (1024x1024 according to the
        # original comment) to 256x256, since the VGG losses are computed on
        # 256x256 images. ``F.interpolate`` replaces the deprecated
        # ``F.upsample``; ``align_corners=False`` is the behaviour the old call
        # defaulted to.
        generated = F.interpolate(
            generated, size=(256, 256), mode="bilinear", align_corners=False
        )

        is_report_step = step % 100 == 0

        # Only pay for the GPU->CPU copy when an image is actually written.
        if is_report_step and opt.verbose:
            # The original code called ``generated.clamp(-1, 1)`` and threw the
            # result away. The clamp is applied to the dumped image only, so
            # the optimisation itself is numerically unchanged while the saved
            # pixels stay inside [0, 255].
            res_img = generated.detach().clamp(-1, 1).cpu().float().numpy()
            # reshape from batchxcxhxw to batchxhxwxc and scale to [0, 255].
            res_img = (np.transpose(res_img, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
            imageio.imsave(
                "{}/regenerated_before_opt_{}.png".format(opt.output_folder, step),
                np.rint(res_img[0]).astype(np.uint8),
            )

        # All three terms are evaluated because the periodic log line reports
        # the VGG and L1 values regardless of which loss drives the update.
        recLoss = vgg_criterion(generated, img_batch)
        recLoss_l1 = l1_criterion(generated, img_batch)
        recLoss_l2 = l2_criterion(generated, img_batch)

        if opt.loss_type == 1:
            total_loss = recLoss + recLoss_l1 * 5
        elif opt.loss_type == 2:
            total_loss = recLoss_l1
        elif opt.loss_type == 3:
            total_loss = recLoss_l2
        elif opt.loss_type == 4:
            total_loss = recLoss
        else:
            raise ValueError(
                "Unsupported loss_type {!r}; expected 1, 2, 3 or 4.".format(
                    opt.loss_type
                )
            )

        optZ.zero_grad()
        # NOTE(review): the graph is rebuilt on every iteration, so
        # retain_graph looks unnecessary; it is kept in case ``img_batch``
        # carries an autograd graph that must survive across iterations.
        total_loss.backward(retain_graph=True)
        optZ.step()

        if is_report_step:
            logging.info(
                "iter: {}. vgg_loss: {:05f}, l1_loss: {:05f}, recloss: {:05f}".format(
                    step,
                    recLoss.item(),
                    recLoss_l1.item(),
                    total_loss.item(),
                )
            )

    return cur_latents