in models/generator.py [0:0]
def sample_local_latents(self, local_latents, xyz):
    """Look up a latent code for every ray sample by interpolating a latent grid.

    Args:
        local_latents: latent grid, either 4D ``[B, local_z_dim, H, W]`` (2D grid, sampled with
            the x and z coordinates only) or 5D ``[B, local_z_dim, D, H, W]`` (3D grid, sampled
            with all three coordinates).
        xyz: sample positions of shape ``[B, num_rays, samples_per_ray, 3]``. They are passed to
            ``grid_sample`` as-is, so they are interpreted in its normalized coordinate
            convention; positions falling outside the grid read zeros (``padding_mode="zeros"``).

    Returns:
        A tuple ``(sampled_local_latents, local_latents)`` where ``sampled_local_latents`` has
        shape ``[B * num_rays * samples_per_ray, local_z_dim]`` and ``local_latents`` is the
        input grid, returned unchanged.

    Raises:
        ValueError: if ``local_latents`` is neither 4D nor 5D.
    """
    if local_latents.ndim == 4:
        # take only x and z coordinates, since our latent codes are in a 2D grid (no y dimension)
        # for the purposes of grid_sample we treat num_rays as the H dimension and samples_per_ray
        # as the W dimension
        xyz = xyz[:, :, :, [0, 2]]  # [B, num_rays, samples_per_ray, 2]
        channels_last = (0, 2, 3, 1)
    elif local_latents.ndim == 5:
        B, num_rays, samples_per_ray, _ = xyz.shape
        # grid_sample evaluates every grid location independently and the result is flattened
        # below, so the rays do not need to be arranged as a square image: a singleton depth
        # axis yields the same rows in the same order and also supports non-square ray counts.
        xyz = xyz.view(B, 1, num_rays, samples_per_ray, 3)
        channels_last = (0, 2, 3, 4, 1)
    else:
        raise ValueError(
            "local_latents must be 4D [B, C, H, W] or 5D [B, C, D, H, W], "
            "got {} dimensions".format(local_latents.ndim)
        )

    local_z_dim = local_latents.shape[1]

    # all samples get the most detailed latent codes
    sampled_local_latents = nn.functional.grid_sample(
        input=local_latents,
        grid=xyz,
        mode='bilinear',  # bilinear mode will use trilinear interpolation if input is 5D
        align_corners=False,
        padding_mode="zeros",
    )
    # output is [B, local_z_dim, num_rays, samples_per_ray] for a 4D grid and
    # [B, local_z_dim, 1, num_rays, samples_per_ray] for a 5D grid

    # put channel dimension at end: [B, (1,) num_rays, samples_per_ray, local_z_dim]
    sampled_local_latents = sampled_local_latents.permute(*channels_last)

    # merge everything else into batch dim: [B * num_rays * samples_per_ray, local_z_dim]
    sampled_local_latents = sampled_local_latents.reshape(-1, local_z_dim)

    return sampled_local_latents, local_latents