def generate()

in models/gsn.py [0:0]


    def generate(self, z, camera_params):
        """Render RGB and depth images for latent codes ``z`` along camera trajectories.

        In training mode the live ``decoder`` / ``generator`` / ``texture_net`` are
        used; in eval mode their ``*_ema`` counterparts are used instead.

        Args:
            z: latent codes fed to the decoder; ``z.device`` is also used for the
                pixel-index tensor when rendering in patches.
            camera_params: dict with ``'K'`` and optionally ``'Rt'``. Both are
                rearranged with ``'b t h w -> (b t) h w'``, so they must be 4D
                ``[B, T, h, w]`` tensors. If ``'Rt'`` is missing it is sampled via
                ``self.trajectory_sampler`` and written back into ``camera_params``
                (the caller's dict is mutated in place).

        Returns:
            Tuple ``(rgb, depth, Rt, K)``:
                rgb: ``texture_net`` output reshaped to ``[B, T, C, H, W]``.
                depth: ``[B, T, 1, nerf_out_res, nerf_out_res]``.
                Rt, K: camera matrices as returned by the generator (from the last
                    rendered chunk), reshaped back to ``[B, T, h, w]``.
        """
        params = self.generator_config.params
        nerf_out_res = params.nerf_out_res
        samples_per_ray = params.samples_per_ray

        # use EMA weights if in eval mode
        decoder = self.decoder if self.training else self.decoder_ema
        generator = self.generator if self.training else self.generator_ema
        texture_net = self.texture_net if self.training else self.texture_net_ema

        # map 1D latent code z to 2D latent code w
        w = decoder(z=z)

        if 'Rt' not in camera_params:
            # NOTE(review): trajectories are sampled with the non-EMA self.generator
            # even in eval mode, while rendering below uses `generator` (EMA in eval
            # mode) -- confirm this is intended.
            camera_params['Rt'] = self.trajectory_sampler.sample_trajectories(self.generator, w)

        # duplicate latent codes along the trajectory dimension
        T = camera_params['Rt'].shape[1]  # trajectory length
        w = repeat(w, 'b c h w -> b t c h w', t=T)
        w = rearrange(w, 'b t c h w -> (b t) c h w')

        # Flatten batch and trajectory dims of the cameras once. Each chunk still gets
        # its own clone, presumably so the generator cannot modify camera_params'
        # tensors (or state shared between chunks) through render_params.
        Rt_flat = rearrange(camera_params['Rt'], 'b t h w -> (b t) h w')
        K_flat = rearrange(camera_params['K'], 'b t h w -> (b t) h w')

        n_pixels = nerf_out_res ** 2
        if self.patch_size is None or nerf_out_res <= self.patch_size:
            # compute full image in one pass
            indices_chunks = [None]
        else:
            # Break the whole image into manageable pieces and compute each separately.
            # torch.split caps every piece at patch_size**2 pixels; deriving a chunk
            # *count* by truncating n_pixels / patch_size**2 would let pieces exceed it.
            indices = torch.arange(n_pixels, device=z.device)
            indices_chunks = torch.split(indices, self.patch_size ** 2)

        rgb, depth = [], []
        for indices in indices_chunks:
            render_params = RenderParams(
                Rt=Rt_flat.clone(),
                K=K_flat.clone(),
                samples_per_ray=samples_per_ray,
                near=params.near,
                far=params.far,
                alpha_noise_std=params.alpha_noise_std,
                nerf_out_res=nerf_out_res,
                mask=indices,
            )

            y_hat = generator(local_latents=w, render_params=render_params)
            rgb.append(y_hat['rgb'])  # shape [BT, pixels_in_chunk, C]
            depth.append(y_hat['depth'])

        # combine image patches back into full images (pixel dimension is dim 1)
        rgb = torch.cat(rgb, dim=1)
        depth = torch.cat(depth, dim=1)

        rgb = rearrange(rgb, 'b (h w) c -> b c h w', h=nerf_out_res, w=nerf_out_res)
        rgb = texture_net(rgb)
        rgb = rearrange(rgb, '(b t) c h w -> b t c h w', t=T)

        depth = rearrange(depth, '(b t) (h w) -> b t 1 h w', t=T, h=nerf_out_res, w=nerf_out_res)

        # y_hat is from the last chunk; the loop always runs at least once
        Rt = rearrange(y_hat['Rt'], '(b t) h w -> b t h w', t=T)
        K = rearrange(y_hat['K'], '(b t) h w -> b t h w', t=T)

        return rgb, depth, Rt, K