in empose/data/transforms.py [0:0]
def __call__(self, batch: ABatch):
    """
    Place virtual markers on the mesh vertices and displace them with synthetic offsets.

    The markers are first obtained from `self.virtual_helper` for every frame of every sequence. Then one of the
    `self.n_offsets` pre-defined offset sets is drawn per sequence and applied in the local marker frame, both
    as a translation and as a rotation. If `self.randomize` is set, `self.noise_level` selects how the
    translational offset is produced:

      0: one sample from `self.normal_dists` per sequence, shared by all frames.
      1: one sample from `self.normal_dists` per frame.
      2: zero translational offset; the rotational offset is kept.
      3: zero translational offset and identity rotational offset.

    Without `self.randomize` the mean offsets are applied as-is.

    :param batch: The batch to augment. Reads `batch_size`, `seq_length` and `vertices`.
    :return: The same batch object with these attributes added, all reshaped to (N, F, -1) unless noted:
      `marker_pos_vertex`, `marker_ori_vertex`, `marker_normal_vertex` (markers before offsets),
      `marker_pos_synth`, `marker_ori_synth`, `marker_normal_synth` (markers after offsets),
      `offset_t_augmented` (mean translational offsets per sequence) and
      `offset_r_augmented` (rotational offsets of the first frame per sequence).
    :raises ValueError: If `self.randomize` is set and `self.noise_level` is not one of 0, 1, 2, 3.
    """
    n, f = batch.batch_size, batch.seq_length
    vs = batch.vertices.reshape(n * f, -1, 3)
    markers, marker_oris, marker_normals = self.virtual_helper.get_virtual_pos_and_rot(vs, self.vertex_ids)

    # Store the local marker positions and orientations (certain models might have them as targets).
    batch.marker_pos_vertex = markers.clone().detach().reshape(n, f, -1)
    batch.marker_ori_vertex = marker_oris.clone().detach().reshape(n, f, -1)
    batch.marker_normal_vertex = marker_normals.clone().detach().reshape(n, f, -1)

    # Draw one offset set per sequence. By default the mean offsets are used for every frame.
    s_idxs = self.offset_rng.randint(0, self.n_offsets, n)
    offset_means = torch.from_numpy(self.offset_means[s_idxs]).to(dtype=markers.dtype)
    local_offsets = offset_means.clone().unsqueeze(1).repeat(1, f, 1, 1)
    s_idxs = torch.from_numpy(s_idxs).to(dtype=torch.long)

    if self.randomize:
        if self.noise_level == 0:
            # One noisy offset per sequence, repeated over all frames.
            offset_noise = self.normal_dists.sample((n, ))  # (N, N_OFFSETS, M, 3)
            offset_noise = offset_noise[torch.arange(n), s_idxs]  # (N, M, 3)
            local_offsets = offset_noise.unsqueeze(1).repeat(1, f, 1, 1)
        elif self.noise_level == 1:
            # A different noisy offset for every frame, always from the offset set drawn for the sequence.
            offset_noise = self.normal_dists.sample((n, f))  # (N, F, N_OFFSETS, M, 3)
            s = s_idxs.unsqueeze(-1).repeat(1, f).reshape(-1)
            offset_noise = offset_noise.reshape((n * f, self.n_offsets, -1, 3))[torch.arange(n * f), s]
            local_offsets = offset_noise.reshape((n, f, -1, 3))
        elif self.noise_level == 2 or self.noise_level == 3:
            local_offsets = torch.zeros_like(local_offsets)
        else:
            raise ValueError("Unknown noise level {}".format(self.noise_level))

    local_offsets = local_offsets.to(device=markers.device)

    # Apply offsets to marker position. Only the trailing matmul dimension must be squeezed: a bare `squeeze()`
    # would also drop the batch, frame or marker dimension whenever it has size 1 and the subsequent addition
    # would then broadcast to a wrong shape.
    ms = markers.reshape((n, f, -1, 3))
    ori_synth = marker_oris.reshape((n, f, -1, 3, 3))
    markers_new = ms + torch.matmul(ori_synth, local_offsets.unsqueeze(-1)).squeeze(-1)

    # Apply offset to marker orientation. The rotational offsets are converted to a tensor once and cached on
    # the instance for subsequent calls.
    if isinstance(self.r, np.ndarray):
        self.r = torch.from_numpy(self.r).to(dtype=local_offsets.dtype, device=local_offsets.device)
    r = self.r[s_idxs].unsqueeze(1).repeat(1, f, 1, 1, 1)
    if self.randomize:
        if self.noise_level == 3:
            # Replace the rotational offsets with the identity.
            r = torch.zeros_like(r)
            r[:, :, :, 0, 0] = 1.0
            r[:, :, :, 1, 1] = 1.0
            r[:, :, :, 2, 2] = 1.0
    ori_synth = torch.matmul(ori_synth, r)
    # The marker normal is the last column of the marker orientation.
    marker_normals = ori_synth[..., 2]

    # Order is already correct since we load the vertex IDs from the offset file.
    batch.marker_pos_synth = markers_new.reshape(n, f, -1)
    batch.marker_ori_synth = ori_synth.reshape(n, f, -1)
    batch.marker_normal_synth = marker_normals.reshape(n, f, -1)

    # Store the information to revert synthetic offsets. We always take the mean of the offsets for this, since
    # this is the information we'll have during test as well.
    batch.offset_t_augmented = offset_means.clone().detach().to(device=markers.device)
    batch.offset_r_augmented = r[:, 0].clone().detach().to(device=markers.device)  # Take first frame.
    return batch