in models/generator.py [0:0]
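# NOTE (hedged): this excerpt assumes module-level imports along the lines of
# those below; the exact import path for the NeRF helpers is an assumption and
# may differ in the repo.
#     import torch
#     from einops import repeat
#     from utils.nerf_utils import get_sample_points, volume_render_radiance_field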
def forward(self, local_latents, render_params=None, xyz=None):
    assert (xyz is not None) ^ (render_params is not None), 'Use either xyz or render_params, not both.'
    return_alpha_only = xyz is not None
    if render_params is not None:
        H, W = render_params.nerf_out_res, render_params.nerf_out_res
        # When rendering a feature NeRF, the camera intrinsics must be rescaled
        # to account for the lower sampling resolution.
        if self.img_res is not None:
            downsampling_ratio = render_params.nerf_out_res / self.img_res
        else:
            downsampling_ratio = 1
        fx, fy = render_params.K[0, 0, 0] * downsampling_ratio, render_params.K[0, 1, 1] * downsampling_ratio
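        # e.g. with img_res=256 and nerf_out_res=64 the ratio is 0.25, so the
        # focal lengths from K are scaled down to match the coarser ray grid.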
        xyz, viewdirs, z_vals, rd, ro = get_sample_points(
            tform_cam2world=render_params.Rt.inverse(),
            F=(fx, fy),
            H=H,
            W=W,
            samples_per_ray=render_params.samples_per_ray,
            near=render_params.near,
            far=render_params.far,
            perturb=self.training,
            mask=render_params.mask,
        )
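        # Hedged reading from their use below: xyz are the per-ray sample points
        # [B, n_rays, samples_per_ray, 3], viewdirs the matching view directions,
        # z_vals the sample depths along each ray, and rd/ro the ray directions
        # and origins used for fine re-sampling and compositing.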
    else:
        xyz = xyz.unsqueeze(1)  # expand to shape [B, 1, n_query_points, 3]
        viewdirs = None
    # coarse prediction
    rgb_coarse, alpha_coarse = self.query_network(xyz, local_latents, viewdirs)
    if return_alpha_only:
        return alpha_coarse
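    # The xyz-only path skips volume rendering entirely: it is a direct
    # occupancy query at arbitrary world points (inferred from the early
    # return above).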
    if self.hierarchical_sampling:
        _, _, _, weights, _, occupancy_prior = volume_render_radiance_field(
            rgb=rgb_coarse,
            occupancy=alpha_coarse,
            depth_values=z_vals,
            ray_directions=rd,
            radiance_field_noise_std=render_params.alpha_noise_std,
            alpha_activation=self.alpha_activation,
            activate_rgb=not self.feature_nerf,
            density_bias=self.density_bias,
        )
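        # The coarse pass above is rendered only to obtain per-sample weights;
        # the fine pass below re-samples depths where those weights concentrate
        # (standard NeRF importance sampling).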
        z_vals_fine = self.importance_sampling(z_vals, weights, render_params.samples_per_ray)
        xyz = ro[..., None, :] + rd[..., None, :] * z_vals_fine[..., :, None]
        viewdirs = viewdirs[:, :, 0:1].expand_as(xyz)
        rgb_fine, alpha_fine = self.query_network(xyz, local_latents, viewdirs)
        rgb = torch.cat([rgb_coarse, rgb_fine], dim=-2)
        alpha = torch.cat([alpha_coarse, alpha_fine], dim=-1)
        z_vals = torch.cat([z_vals, z_vals_fine], dim=-1)
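        # Sort the merged coarse + fine samples by depth so that z_vals stay
        # monotonically increasing along each ray, as front-to-back alpha
        # compositing requires; rgb and alpha are reordered to match.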
        _, indices = torch.sort(z_vals, dim=-1)
        z_vals = torch.gather(z_vals, -1, indices)
        rgb_indices = repeat(indices, 'b n_rays n_samples -> b n_rays n_samples d', d=rgb.shape[-1])
        rgb = torch.gather(rgb, -2, rgb_indices)
        alpha = torch.gather(alpha, -1, indices)
    else:
        rgb, alpha = rgb_coarse, alpha_coarse
    rgb, _, _, _, depth, occupancy_prior = volume_render_radiance_field(
        rgb=rgb,
        occupancy=alpha,
        depth_values=z_vals,
        ray_directions=rd,
        radiance_field_noise_std=render_params.alpha_noise_std,
        alpha_activation=self.alpha_activation,
        activate_rgb=not self.feature_nerf,
        density_bias=self.density_bias,
    )

    out = {
        'rgb': rgb,
        'depth': depth,
        'Rt': render_params.Rt,
        'K': render_params.K,
        'local_latents': local_latents,
        'occupancy_prior': occupancy_prior,
    }
    return out
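
# Minimal usage sketch (hedged; `generator` and `render_params` below are
# illustrative stand-ins, not names defined in this file):
#
#     # Full render: render_params supplies K, Rt, near/far bounds,
#     # nerf_out_res, samples_per_ray, alpha_noise_std, and an optional mask.
#     out = generator(local_latents, render_params=render_params)
#     rgb, depth = out['rgb'], out['depth']
#
#     # Occupancy-only query at arbitrary world points, xyz: [B, n_query_points, 3].
#     alpha = generator(local_latents, xyz=xyz)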