in mmf/models/mesh_renderer.py [0:0]
def forward(self, sample_list):
    """Run the mesh-renderer forward pass, gated by flags on ``self.config``.

    1. Geometry: get ``(xy_offset, z_grid)`` from
       ``self.get_offset_and_depth_from_gt`` when ``fill_z_with_gt`` is set,
       otherwise from ``self.offset_and_depth_predictor`` applied to
       ``sample_list.trans_img_0``. ``force_zero_xy_offset`` replaces the
       offsets with zeros.
    2. Rendering (skipped when ``train_z_grid_only`` is set): call
       ``self.novel_view_projector`` on ``sample_list.orig_img_0`` with input
       ``(R_0, T_0)`` and output views ``(R_0, T_0)`` and ``(R_1, T_1)``
       (presumably camera rotations/translations). The second reconstructed
       RGBA becomes ``rgb_1_out``, optionally refined by
       ``self.inpainting_net_G`` when ``use_inpainting`` is set.
    3. If ``return_rendering_results_only`` is set, return the rendering dict.
    4. Otherwise compute losses via ``self.forward_losses``. Compute metrics
       via ``self.forward_metrics`` in eval mode, or in training when
       ``metrics.only_on_eval`` is false. Optionally call
       ``self.save_forward_results``.

    Args:
        sample_list: batch object exposing ``trans_img_0``, ``orig_img_0``,
            ``orig_img_1``, ``R_0``/``T_0``/``R_1``/``T_1``, ``dataset_type``
            and ``dataset_name``, as read below.

    Returns:
        The rendering-results dict when ``return_rendering_results_only`` is
        set (``{}`` if rendering was skipped), otherwise ``{"losses": losses}``.
    """
    config = self.config

    # Stage 1: mesh geometry (xy offsets + depth grid).
    if config.fill_z_with_gt:
        xy_offset, z_grid = self.get_offset_and_depth_from_gt(sample_list)
    else:
        # The predictor consumes the transformed image (after mean subtraction
        # and normalization), not the raw RGB image.
        xy_offset, z_grid = self.offset_and_depth_predictor(sample_list.trans_img_0)
    if config.force_zero_xy_offset:
        xy_offset = torch.zeros_like(xy_offset)

    # Stage 2: novel-view rendering (skipped when only the z grid is trained).
    rendering_results = {}
    if not config.train_z_grid_only:
        # Rendering consumes the original image (RGB values in 0~1).
        rendering_results = self.novel_view_projector(
            xy_offset=xy_offset,
            z_grid=z_grid,
            rgb_in=sample_list.orig_img_0,
            R_in=sample_list.R_0,
            T_in=sample_list.T_0,
            R_out_list=[sample_list.R_0, sample_list.R_1],
            T_out_list=[sample_list.T_0, sample_list.T_1],
            render_mesh_shape=config.render_mesh_shape_for_vis,
        )
        # Only the second rendered view (view 1) is used below. Its last axis
        # is treated as RGB in [..., :3] and alpha in [..., -1].
        _, rgba_view1 = rendering_results["rgba_out_rec_list"]
        if not config.use_inpainting:
            rendering_results["rgb_1_out"] = rgba_view1[..., :3]
        else:
            if config.sanity_check_inpaint_with_gt:
                # Sanity check: feed the ground-truth view-1 image (with alpha
                # set to 1 everywhere) to make sure the generator has enough
                # capacity to perfectly reconstruct it.
                rgba_view1 = torch.ones_like(rgba_view1)
                rgba_view1[..., :3] = sample_list.orig_img_1
            inpainted = self.inpainting_net_G(rgba_view1)
            if config.inpainting.inpaint_missing_regions_only:
                # Keep the rendered RGB wherever alpha >= 1e-4 and use the
                # generator output only where (almost) nothing was rendered.
                visible = rgba_view1[..., -1].unsqueeze(-1).ge(1e-4).float()
                inpainted = inpainted * (1 - visible) + rgba_view1[..., :3] * visible
            rendering_results["rgb_1_inpaint"] = inpainted
            rendering_results["rgb_1_out"] = rendering_results["rgb_1_inpaint"]

    # Return only the rendering results and skip loss computation, usually for
    # visualization on-the-fly by calling this model separately (instead of
    # running it within the MMF trainer on MMF datasets).
    if config.return_rendering_results_only:
        return rendering_results

    # Stage 3: losses, metrics, and optional dumping of forward results.
    losses = self.forward_losses(sample_list, xy_offset, z_grid, rendering_results)

    # Metrics run in eval mode always; in training only when
    # `metrics.only_on_eval` is disabled.
    if not (self.training and config.metrics.only_on_eval):
        metrics_dict = self.forward_metrics(sample_list, rendering_results)
        rendering_results.update(metrics_dict)
        # Average each metric over the batch and detach it so no gradient
        # flows through the metric entries added to the losses dict.
        metric_entries = {
            f"{sample_list.dataset_type}/{sample_list.dataset_name}"
            f"/no_grad_{name}": value.detach().mean()
            for name, value in metrics_dict.items()
        }
        losses.update(metric_entries)

    if config.save_forward_results:
        self.save_forward_results(sample_list, xy_offset, z_grid, rendering_results)

    return {"losses": losses}