def sample_local_latents()

in models/generator.py [0:0]


    def sample_local_latents(self, local_latents, xyz):
        """Interpolate the local latent grid at every sample position.

        Args:
            local_latents: Latent grid, either 4D ``[B, local_z_dim, H, W]``
                (a 2D grid over x/z, no y dimension) or 5D
                ``[B, local_z_dim, D, H, W]`` (a 3D volume).
            xyz: Sample positions of shape ``[B, H * W, samples_per_ray, 3]``.
                They are handed to ``grid_sample`` unchanged, so they are
                expected to be normalized to ``[-1, 1]``; positions outside
                the grid read zeros (``padding_mode="zeros"``). For a 5D
                grid the number of rays (``H * W``) must be a perfect square.

        Returns:
            A tuple ``(sampled_local_latents, local_latents)``:
            ``sampled_local_latents`` has shape
            ``[B * H * W * samples_per_ray, local_z_dim]`` and
            ``local_latents`` is the input grid, returned untouched.

        Raises:
            ValueError: If ``local_latents`` is neither 4D nor 5D, or if a 5D
                grid is combined with a non-square number of rays.
        """
        latent_ndim = local_latents.ndim

        if latent_ndim == 4:
            # take only x and z coordinates, since the latent codes live in a 2D grid (no y dimension)
            # for grid_sample, H * W acts as the output H dimension and samples_per_ray as the output W
            grid = xyz[:, :, :, [0, 2]]  # [B, H * W, samples_per_ray, 2]
        elif latent_ndim == 5:
            batch, num_rays, samples_per_ray, _ = xyz.shape
            side = int(round(float(np.sqrt(num_rays))))
            if side * side != num_rays:
                raise ValueError(
                    f"5D local_latents require a square number of rays, got {num_rays}"
                )
            # 5D inputs need a 5D grid: unfold the flat ray axis back into a (side, side) image
            grid = xyz.view(batch, side, side, samples_per_ray, 3)
        else:
            raise ValueError(
                f"local_latents must be 4D or 5D, got {latent_ndim}D"
            )

        local_z_dim = local_latents.shape[1]

        # all samples get the most detailed latent codes
        sampled_local_latents = nn.functional.grid_sample(
            input=local_latents,
            grid=grid,
            mode='bilinear',  # bilinear mode will use trilinear interpolation if input is 5D
            align_corners=False,
            padding_mode="zeros",
        )
        # output is [B, local_z_dim, H * W, samples_per_ray] for a 4D grid and
        # [B, local_z_dim, side, side, samples_per_ray] for a 5D grid

        # put channel dimension at end, keeping the remaining axes in order
        sampled_local_latents = sampled_local_latents.permute(
            0, *range(2, sampled_local_latents.ndim), 1
        )

        # merge everything else into batch dim: [B * H * W * samples_per_ray, local_z_dim]
        sampled_local_latents = sampled_local_latents.reshape(-1, local_z_dim)

        return sampled_local_latents, local_latents