def sample_trajectories()

in models/model_utils.py [0:0]


    def sample_trajectories(self, generator, local_latents):
        """Return trajectories that best traverse a given scene.

        The selection strategy is chosen by ``self.mode``:

        - ``'sample'``: score the real trajectories with ``self.get_occupancy``, then randomly
          keep up to 1000 of them. For each row of the occupancy tensor, draw one trajectory
          with probability ``softmin(occupancy)``, so emptier trajectories are favoured.
        - ``'bin'``: for each scene, pick a random bin from ``self.bin_indices``. Take the
          trajectory in that bin with the lowest occupancy.
        - ``'random'``: draw ``B`` trajectories uniformly, ignoring occupancy. Sampling is
          without replacement unless ``B`` exceeds the number of available trajectories.

        If ``self.jitter_range`` is truthy, each trajectory first gets one random x/z offset,
        drawn uniformly from ``[-jitter_range, jitter_range]``. The offset is applied to the
        translation of the inverted extrinsics and is shared by all frames of that
        trajectory. The y component is never jittered.

        NOTE(review): ``get_occupancy`` is assumed to return a tensor indexed as
        ``[scene, trajectory]``, judging from how it is indexed below. Confirm against its
        definition.

        Input:
        -----
        generator: SceneGenerator
            Generator object to be evaluated for occupancy.
        local_latents: torch.Tensor
            Local latent codes of shape [B, local_z_dim, H, W] corresponding to the scenes that will be evaluated.

        Return:
        ------
        Rts: torch.Tensor
            Trajectories of camera extrinsic matrices of shape [B, seq_len, 4, 4], moved to the
            device of ``local_latents``.

        Raises:
        ------
        ValueError
            If ``self.mode`` is not one of ``'sample'``, ``'bin'`` or ``'random'``.

        """
        B = local_latents.shape[0]

        # Work on a copy so the stored trajectories are never modified by the jitter below.
        real_Rts = self.real_Rts.clone()
        if self.jitter_range:
            n_trajectories = real_Rts.shape[0]

            # One offset per trajectory; the singleton axis broadcasts it over every frame.
            jitter = torch.rand(size=(n_trajectories, 1, 3), device=real_Rts.device, requires_grad=False)
            jitter = (jitter * 2) - 1  # normalize to [-1, 1]
            jitter = jitter * self.jitter_range
            jitter[:, :, 1] = 0  # no jitter on the y axis

            # Shift the translation column of the inverted extrinsics, then invert back.
            camera_pose = real_Rts.inverse()
            camera_pose[:, :, :3, 3] = camera_pose[:, :, :3, 3] + jitter
            trajectories = camera_pose[:, :, :3, 3]
            real_Rts = camera_pose.inverse()
        else:
            trajectories = self.real_trajectories

        n_total = real_Rts.shape[0]

        if self.mode == 'sample':
            occupancy = self.get_occupancy(generator=generator, local_latents=local_latents, trajectories=trajectories)

            # randomly choose (up to) 1k trajectories to sample from
            n_subsamples = min(n_total, 1000)
            subset_indices = torch.multinomial(torch.ones(n_total), num_samples=n_subsamples, replacement=False)

            # Lower occupancy -> larger weight.
            sample_weights = nn.functional.softmin(occupancy[:, subset_indices] + 1e-8, dim=-1)
            # Replace NaN weights with a constant. torch.multinomial does not require rows to
            # sum to one, so an all-NaN row becomes a uniform draw.
            nans = torch.isnan(sample_weights)
            sample_weights[nans] = 1 / 1000
            selected_indices = torch.multinomial(sample_weights, num_samples=1, replacement=False).squeeze(1)

            # selected_indices is on occupancy's device. Index tensors must be on CPU or on the
            # indexed tensor's device, so move them next to real_Rts.
            Rts = real_Rts[subset_indices][selected_indices.to(real_Rts.device)]

        elif self.mode == 'bin':
            occupancy = self.get_occupancy(generator=generator, local_latents=local_latents, trajectories=trajectories)

            # shuffle trajectories so that we don't always select the first completely unoccupied trajectory
            self.shuffle_trajectories_in_bins()

            Rts = []
            for i in range(B):
                selected_bin = self.bin_indices[np.random.choice(a=self.num_bins)]
                # argmin gives a position inside the bin; map it back to a global trajectory index.
                most_empty_idx = selected_bin[torch.argmin(occupancy[i, selected_bin])]
                Rts.append(real_Rts[most_empty_idx])
            Rts = torch.stack(Rts, dim=0)

        elif self.mode == 'random':
            weight = torch.ones(n_total)
            # Prefer distinct trajectories. Fall back to replacement only when the batch is
            # larger than the pool, where sampling without replacement would raise.
            selected_indices = torch.multinomial(weight, num_samples=B, replacement=B > n_total)
            Rts = real_Rts[selected_indices]

        else:
            raise ValueError(
                f"Unknown trajectory sampling mode {self.mode!r}; expected 'sample', 'bin' or 'random'."
            )

        return Rts.to(local_latents.device)