# models/gsn.py

import torch
from einops import rearrange, repeat

# NOTE: the import below is an assumption about where RenderParams lives in
# this repo; adjust the path to match the actual module layout.
from models.model_utils import RenderParams
def generate(self, z, camera_params):
    """Render RGB and depth frames from latent code z along a camera trajectory.

    camera_params is a dict holding 'Rt' and 'K' camera matrices; if 'Rt' is
    absent, camera trajectories are sampled automatically. Returns rgb, depth,
    and the Rt and K camera parameters used for rendering.
    """
    nerf_out_res = self.generator_config.params.nerf_out_res
    samples_per_ray = self.generator_config.params.samples_per_ray

    # use the EMA weights when in eval mode
    decoder = self.decoder if self.training else self.decoder_ema
    generator = self.generator if self.training else self.generator_ema
    texture_net = self.texture_net if self.training else self.texture_net_ema
    # map the 1D global latent code z to a 2D grid of local latent codes w
    w = decoder(z=z)
    # sample a camera trajectory if none was provided
    if 'Rt' not in camera_params:
        Rt = self.trajectory_sampler.sample_trajectories(self.generator, w)
        camera_params['Rt'] = Rt
    # duplicate the latent codes along the trajectory dimension, so every
    # frame in a trajectory is rendered from the same scene representation
    T = camera_params['Rt'].shape[1]  # trajectory length
    w = repeat(w, 'b c h w -> b t c h w', t=T)
    w = rearrange(w, 'b t c h w -> (b t) c h w')
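    # shape recap (from the einops patterns above; the shape of z itself is
    # an assumption about the decoder's input):
    #   z: [B, z_dim] -> w: [B, C, H, W] -> [B, T, C, H, W] -> [B*T, C, H, W]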
    if self.patch_size is None or nerf_out_res <= self.patch_size:
        # compute the full image in one pass
        indices_chunks = [None]
    else:
        # break the image into manageable patches, then render each separately
        indices = torch.arange(nerf_out_res ** 2, device=z.device)
        indices_chunks = torch.chunk(indices, chunks=int(nerf_out_res ** 2 / self.patch_size ** 2))
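    # e.g. with nerf_out_res = 64 and patch_size = 32 (illustrative values),
    # the 64 ** 2 = 4096 ray indices are split into 4096 / 32 ** 2 = 4 chunks,
    # and each chunk is rendered below as one patch of 1024 pixels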
    rgb, depth = [], []
    for indices in indices_chunks:
        render_params = RenderParams(
            # flatten the batch and trajectory dims to render all frames at once
            Rt=rearrange(camera_params['Rt'], 'b t h w -> (b t) h w').clone(),
            K=rearrange(camera_params['K'], 'b t h w -> (b t) h w').clone(),
            samples_per_ray=samples_per_ray,
            near=self.generator_config.params.near,
            far=self.generator_config.params.far,
            alpha_noise_std=self.generator_config.params.alpha_noise_std,
            nerf_out_res=nerf_out_res,
            mask=indices,  # which ray indices to render this pass (None = all)
        )
        y_hat = generator(local_latents=w, render_params=render_params)
        rgb.append(y_hat['rgb'])  # shape [BT, HW, C]
        depth.append(y_hat['depth'])
    # combine the image patches back into full images
    rgb = torch.cat(rgb, dim=1)
    depth = torch.cat(depth, dim=1)
    rgb = rearrange(rgb, 'b (h w) c -> b c h w', h=nerf_out_res, w=nerf_out_res)
    rgb = texture_net(rgb)  # refine the raw radiance-field RGB output
    rgb = rearrange(rgb, '(b t) c h w -> b t c h w', t=T)
    depth = rearrange(depth, '(b t) (h w) -> b t 1 h w', t=T, h=nerf_out_res, w=nerf_out_res)

    # return the camera parameters actually used for rendering, including any
    # sampled trajectories
    Rt = rearrange(y_hat['Rt'], '(b t) h w -> b t h w', t=T)
    K = rearrange(y_hat['K'], '(b t) h w -> b t h w', t=T)
    return rgb, depth, Rt, K
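
# Example usage (a hypothetical sketch; the model construction and the exact
# shapes of z and K are assumptions, not this repo's documented API):
#
#   model = GSN(...)
#   model.eval()  # switches generate() onto the EMA weights
#   z = torch.randn(batch_size, z_dim)
#   camera_params = {'K': K}  # e.g. [B, T, 3, 3]; 'Rt' omitted, so trajectories are sampled
#   rgb, depth, Rt, K = model.generate(z, camera_params)
#   # rgb: [B, T, 3, H, W], depth: [B, T, 1, H, W]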