in mmf/models/mesh_renderer.py [0:0]
def forward_losses(self, sample_list, xy_offset, z_grid, rendering_results):
    """Build the dictionary of weighted training/eval losses for one batch.

    Every loss term is first computed unscaled, then all computed terms are
    checked for non-finite values, and finally each term is multiplied by
    its entry in ``self.loss_weights``. Terms with a zero weight are left
    out of the returned dictionary.

    Args:
        sample_list: Batch object. The attributes ``depth_0``,
            ``depth_mask_0``, ``depth_1``, ``depth_mask_1``, ``orig_img_1``,
            ``dataset_type`` and ``dataset_name`` are read from it.
        xy_offset: Forwarded to ``self.loss_grid_offset``.
        z_grid: Forwarded to ``self.loss_z_grid_l1`` as the prediction.
        rendering_results: Mapping that is only read when
            ``self.config.train_z_grid_only`` is false. The keys used are
            ``"rgba_out_rec_list"`` and ``"depth_out_rec_list"`` (each
            unpacked into exactly two items), ``"scaled_verts"`` and, with
            inpainting enabled, ``"rgb_1_inpaint"``.

    Returns:
        dict: ``"<dataset_type>/<dataset_name>/<loss name>"`` mapped to the
        loss value multiplied by its weight.

    Raises:
        Exception: If any computed unscaled loss contains a non-finite
            value.
    """
    weights = self.loss_weights

    # The z-grid term is only evaluated when it carries weight; a None
    # placeholder keeps the key present for the later filtering step.
    if weights["z_grid_l1_0"] != 0:
        z_grid_term = self.loss_z_grid_l1(
            z_grid_pred=z_grid,
            depth_gt=sample_list.depth_0,
            depth_loss_mask=sample_list.depth_mask_0.float(),
        )
    else:
        z_grid_term = None
    unscaled = {
        "z_grid_l1_0": z_grid_term,
        "grid_offset": self.loss_grid_offset(xy_offset),
    }

    cfg = self.config
    # The perceptual loss can be restricted to training via the config.
    vgg_enabled = self.training or not cfg.vgg19_loss_only_on_train

    if not cfg.train_z_grid_only:
        # Both lists must hold exactly two entries; the first RGBA entry is
        # not needed by any loss below.
        _, rgba_view1 = rendering_results["rgba_out_rec_list"]
        depth_view0, depth_view1 = rendering_results["depth_out_rec_list"]
        verts = rendering_results["scaled_verts"]
        rgb_view1 = rgba_view1[..., :3]

        unscaled["depth_l1_0"] = (
            self.loss_depth_l1(
                depth_pred=depth_view0,
                depth_gt=sample_list.depth_0,
                loss_mask=sample_list.depth_mask_0.float(),
            )
            if weights["depth_l1_0"] != 0
            else None
        )
        unscaled["depth_l1_1"] = (
            self.loss_depth_l1(
                depth_pred=depth_view1,
                depth_gt=sample_list.depth_1,
                loss_mask=sample_list.depth_mask_1.float(),
            )
            if weights["depth_l1_1"] != 0
            else None
        )
        unscaled["image_l1_1"] = self.loss_image_l1(
            rgb_pred=rgb_view1,
            rgb_gt=sample_list.orig_img_1,
            loss_mask=sample_list.depth_mask_1.float(),
        )
        if not vgg_enabled or weights["vgg19_perceptual_1"] == 0:
            # Zero placeholder on the same device as the rendered image.
            unscaled["vgg19_perceptual_1"] = torch.tensor(
                0., device=rgb_view1.device
            )
        else:
            unscaled["vgg19_perceptual_1"] = self.loss_vgg19_perceptual(
                rgb_pred=rgb_view1,
                rgb_gt=sample_list.orig_img_1,
                loss_mask=sample_list.depth_mask_1.float(),
            )
        unscaled["mesh_laplacian"] = self.loss_mesh_laplacian(verts)

        if cfg.use_inpainting:
            inpainted = rendering_results["rgb_1_inpaint"]
            # The inpainted image is compared without a loss mask.
            unscaled["image_l1_1_inpaint"] = self.loss_image_l1(
                rgb_pred=inpainted,
                rgb_gt=sample_list.orig_img_1,
            )
            if not vgg_enabled or weights["vgg19_perceptual_1_inpaint"] == 0:
                unscaled["vgg19_perceptual_1_inpaint"] = torch.tensor(
                    0., device=rgb_view1.device
                )
            else:
                unscaled["vgg19_perceptual_1_inpaint"] = (
                    self.loss_vgg19_perceptual(
                        rgb_pred=inpainted,
                        rgb_gt=sample_list.orig_img_1,
                    )
                )

            if self.use_discriminator:
                # alpha_mask marks where the rendered alpha channel is at
                # least 1e-4; the discriminator is only updated in training.
                unscaled.update(
                    self.mesh_gan_losses(
                        fake_img=inpainted,
                        real_img=sample_list.orig_img_1,
                        alpha_mask=rgba_view1[..., 3:4].ge(1e-4).float(),
                        update_discriminator=self.training,
                    )
                )

    # Fail loudly on NaN/Inf in any computed term (skipped terms are None).
    for name, value in unscaled.items():
        if value is None:
            continue
        if torch.isfinite(value).all().item():
            continue
        raise Exception("loss {} becomes {}".format(name, value.mean().item()))

    # Scale the terms and drop those whose weight is zero.
    losses = {}
    for k, value in unscaled.items():
        weight = weights[k]
        if weight != 0:
            full_key = f"{sample_list.dataset_type}/{sample_list.dataset_name}/{k}"
            losses[full_key] = value * weight
    return losses