in one_shot_domain_adaptation.py [0:0]
def finetune_weights_from_images(opt, img_batch, generator, latents):
    """Fine-tune ``generator``'s weights so it reproduces ``img_batch``.

    The latent codes are held fixed (detached); only ``generator.parameters()``
    are updated, with plain SGD (lr=1) for ``opt.num_iterations`` steps.

    Args:
        opt: Options object. Reads ``num_iterations``, ``loss_type``,
            ``verbose`` and ``output_folder``.
        img_batch: Target images compared against the generator output after
            it is resized to 256x256. Presumably an NCHW tensor in [-1, 1]
            (same convention as the generator output) -- TODO confirm.
        generator: Module called as ``generator.forward(latents)``; its
            weights are updated in place.
        latents: Latent codes fed to the generator; they are not optimized.

    ``opt.loss_type`` selects the reconstruction objective:
        1 -> VGG + 5 * L1
        2 -> 5 * L1
        3 -> L2 (MSE)
        4 -> VGG

    Returns:
        The same ``generator`` object, with fine-tuned weights.

    Raises:
        ValueError: if ``opt.loss_type`` is not 1, 2, 3 or 4.
    """
    loss_type = opt.loss_type
    if loss_type not in (1, 2, 3, 4):
        raise ValueError(
            "Unsupported loss_type {!r}; expected 1, 2, 3 or 4.".format(loss_type)
        )

    latents = latents.detach()
    # The targets are constants for this optimization. Detaching them ensures
    # backward() never traverses (and frees) a graph img_batch might carry, so
    # the graph does not need to be retained between iterations.
    img_batch = img_batch.detach()

    vgg_loss = VGGLoss().to(img_batch.device)
    l1_loss = torch.nn.L1Loss()
    l2_loss = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(generator.parameters(), lr=1)

    for step in range(opt.num_iterations):
        log_step = step % 100 == 0

        generated = generator.forward(latents)
        # Resize the generator output (1024x1024 per the original comment) to
        # 256x256, the resolution the VGG losses are computed at.
        # align_corners=False matches the old F.upsample default.
        generated = F.interpolate(
            generated, size=(256, 256), mode="bilinear", align_corners=False
        )

        if log_step and opt.verbose:
            # Clamp only the preview: the loss has always been computed on the
            # unclamped output (the old `generated.clamp(-1, 1)` discarded its
            # result), and clamping it would zero gradients of saturated pixels.
            preview = generated.detach()[0].clamp(-1, 1).cpu().float().numpy()
            # CxHxW in [-1, 1] -> HxWxC uint8 in [0, 255]. Passing uint8
            # explicitly stops imageio from min/max-rescaling a float image.
            preview = np.transpose(preview, (1, 2, 0))
            preview = ((preview + 1) / 2.0 * 255.0).round().astype(np.uint8)
            imageio.imsave(
                "{}/regenerated_{}.png".format(opt.output_folder, step), preview
            )

        rec_vgg = vgg_loss(generated, img_batch)
        rec_l1 = l1_loss(generated, img_batch) * 5
        # Previously total_loss was unconditionally overwritten with
        # `recLoss + recLoss_l1` (VGG + 5*L1), so loss_type was ignored.
        # Type 1 keeps that effective objective; the other types now take effect.
        if loss_type == 1:
            total_loss = rec_vgg + rec_l1
        elif loss_type == 2:
            total_loss = rec_l1
        elif loss_type == 3:
            total_loss = l2_loss(generated, img_batch)
        else:  # loss_type == 4
            total_loss = rec_vgg

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # .item() forces a device sync, so read values only when logging.
        if log_step:
            logging.info(
                "recLoss: {} recLoss_l1: {:05f}".format(rec_vgg.item(), rec_l1.item())
            )
            logging.info("iter: {} loss: {:05f}".format(step, total_loss.item()))
    return generator