in models/generator.py [0:0]
def query_network(self, xyz, local_latents, viewdirs):
    """Evaluate the local generator at a batch of 3D sample points.

    Args:
        xyz: Sample coordinates, unpacked as
            ``(B, n_rays, samples_per_ray, 3)``. ``n_rays`` is ``H * W``
            when no ray mask is used.
        local_latents: Latent grid handed to ``self.sample_local_latents``.
            When it is 4-D (a 2D latent grid), the y coordinate is kept
            global in local-coordinate mode.
        viewdirs: Viewing directions with a trailing dim of 3, flattened to
            ``(-1, 3)`` before being passed to ``self.local_generator``.
            May be ``None``.

    Returns:
        ``(rgb, alpha)``, where ``rgb`` has shape
        ``(B, n_rays, samples_per_ray, self.out_dim)``, or is ``None`` when
        the generator returns no color. ``alpha`` has shape
        ``(B, n_rays, samples_per_ray)``.
    """
    if self.coordinate_scale is not None:
        # Intended to bring all input coordinates into [-1, 1].
        xyz = xyz / (self.coordinate_scale / 2)

    # Only the batch size and samples-per-ray are needed to restore the
    # output layout; the ray count is recovered with -1 below.
    B, _, samples_per_ray, _ = xyz.shape

    sampled_local_latents, local_latents = self.sample_local_latents(local_latents, xyz=xyz)

    if self.local_coordinates:
        # Map the global coordinate space into local coordinate space
        # (each grid cell spans [-1, 1]). For 4-D latents (a 2D grid),
        # the y coordinate is kept global.
        preserve_y = local_latents.ndim == 4
        xyz = self.get_local_coordinates(
            global_coords=xyz, local_grid_length=self.global_feat_res, preserve_y=preserve_y
        )

    # The generator consumes flat per-point inputs.
    xyz = xyz.reshape(-1, 3)
    viewdirs = viewdirs.reshape(-1, 3) if viewdirs is not None else None
    rgb, alpha = self.local_generator(z=sampled_local_latents, coords=xyz, viewdirs=viewdirs)

    # Restore the per-ray layout:
    #   rgb: [B, n_rays, samples_per_ray, out_dim]
    #   alpha: [B, n_rays, samples_per_ray]
    # reshape (rather than view) matches the input flattening above. It
    # returns a view when possible and copies instead of raising when the
    # generator's output layout cannot be viewed.
    rgb = rgb.reshape(B, -1, samples_per_ray, self.out_dim) if rgb is not None else None
    alpha = alpha.reshape(B, -1, samples_per_ray)
    return rgb, alpha