in models/model_utils.py [0:0]
def sample_trajectories(self, generator, local_latents, *, max_candidates=1000):
    """Return one camera trajectory per scene, chosen according to ``self.mode``.

    Modes:
    ------
    'sample':
        Score every trajectory with ``self.get_occupancy``, draw a random candidate subset of at
        most ``max_candidates`` trajectories, then sample one candidate per scene with weights
        ``softmin(occupancy)`` (lower occupancy -> more likely).
    'bin':
        Shuffle the bins, then for each scene pick a random bin from ``self.bin_indices`` and take
        the trajectory in that bin with the lowest occupancy.
    'random':
        Pick trajectories uniformly at random, ignoring the scene. Sampling is without replacement
        unless B exceeds the number of available trajectories.

    If ``self.jitter_range`` is truthy, the translation column of the inverted matrices
    (``camera_pose`` below) is offset by a random value in ``[-jitter_range, jitter_range)`` on
    the first and third axes before selection. One offset is drawn per trajectory and shared by
    all of its frames; the second (y) axis is never jittered.

    Input:
    ------
    generator: SceneGenerator
        Generator object to be evaluated for occupancy (not used in 'random' mode).
    local_latents: torch.Tensor
        Local latent codes of shape [B, local_z_dim, H, W] corresponding to the scenes that will be evaluated.
    max_candidates: int, keyword-only
        Maximum size of the random candidate subset used in 'sample' mode. Defaults to 1000.

    Return:
    -------
    Rts: torch.Tensor
        Trajectories of camera extrinsic matrices of shape [B, seq_len, 4, 4], moved to ``local_latents.device``.

    Raises:
    -------
    ValueError
        If ``self.mode`` is not one of 'sample', 'bin' or 'random'.
    """
    valid_modes = ('sample', 'bin', 'random')
    if self.mode not in valid_modes:
        # Without this check an unknown mode would surface as an UnboundLocalError on `Rts`.
        raise ValueError(f"Unknown trajectory sampling mode {self.mode!r}; expected one of {valid_modes}")

    B = local_latents.shape[0]
    # No clone needed: every path below builds a new tensor (inverse / advanced indexing / stack).
    real_Rts = self.real_Rts
    n_total = real_Rts.shape[0]

    if self.jitter_range:
        # Shape [n_trajectories, 1, 3]: one offset per trajectory, broadcast over all seq_len frames.
        jitter = torch.rand(size=(n_total, 1, 3), device=real_Rts.device)
        jitter = (jitter * 2 - 1) * self.jitter_range  # rescale [0, 1) -> [-jitter_range, jitter_range)
        jitter[:, :, 1] = 0  # no jitter on the y axis
        camera_pose = real_Rts.inverse()
        camera_pose[:, :, :3, 3] = camera_pose[:, :, :3, 3] + jitter
        trajectories = camera_pose[:, :, :3, 3]
        real_Rts = camera_pose.inverse()
    else:
        trajectories = self.real_trajectories

    if self.mode == 'sample':
        # NOTE(review): occupancy is indexed as [B, n_trajectories] here - confirm against get_occupancy.
        occupancy = self.get_occupancy(generator=generator, local_latents=local_latents, trajectories=trajectories)
        # Randomly choose a candidate subset (without replacement) to sample from.
        n_subsamples = min(n_total, max_candidates)
        subset_indices = torch.multinomial(torch.ones(n_total), num_samples=n_subsamples, replacement=False)
        # softmin is shift-invariant, so no epsilon is needed; lower occupancy -> larger weight.
        sample_weights = nn.functional.softmin(occupancy[:, subset_indices], dim=-1)
        # Non-finite occupancies can turn a row into NaN; fall back to uniform weights there.
        sample_weights[torch.isnan(sample_weights)] = 1 / n_subsamples
        selected_indices = torch.multinomial(sample_weights, num_samples=1, replacement=False).squeeze(1)
        # Compose the indices on CPU (a CPU index tensor is valid for any device) and gather only
        # the B chosen trajectories instead of copying the whole candidate subset first.
        chosen = subset_indices[selected_indices.cpu()]
        Rts = real_Rts[chosen]
    elif self.mode == 'bin':
        occupancy = self.get_occupancy(generator=generator, local_latents=local_latents, trajectories=trajectories)
        # Shuffle trajectories so that we don't always select the first completely unoccupied trajectory.
        self.shuffle_trajectories_in_bins()
        Rts = []
        for i in range(B):
            # Bin choice uses numpy's global RNG, so seeding torch alone does not make this reproducible.
            selected_bin = self.bin_indices[np.random.choice(a=self.num_bins)]
            most_empty_idx = selected_bin[torch.argmin(occupancy[i, selected_bin])]
            Rts.append(real_Rts[most_empty_idx])
        Rts = torch.stack(Rts, dim=0)
    else:  # 'random'
        # Sampling without replacement requires B <= n_total; only fall back to replacement otherwise.
        selected_indices = torch.multinomial(torch.ones(n_total), num_samples=B, replacement=B > n_total)
        Rts = real_Rts[selected_indices]

    Rts = Rts.to(local_latents.device)
    return Rts