in long_term/pose_network_long_term.py [0:0]
def _prepare_next_batch_impl(self, batch_size, dataset, target_length, sequences):
    """Yield batches of (rotations, positions) chunks for long-term training.

    Sequences are sampled with probability proportional to their frame count,
    so every frame of the dataset has roughly the same chance of appearing.
    Enough batches are yielded to cover ``len(sequences)`` samples per epoch
    (rounded up to a whole number of batches).

    Args:
        batch_size: Number of chunks per yielded batch.
        dataset: Mapping ``dataset[subject][action]`` to dicts holding
            'rotations', 'positions_local' and 'extra_features' arrays.
        target_length: Number of frames the model must predict.
        sequences: Iterable of (subject, action) pairs to sample from.

    Yields:
        ``(buffer_rot, buffer_pos)`` — NOTE: the same two arrays are reused
        and overwritten on every pass; callers must copy them if they need
        to retain a batch across iterations.

    Raises:
        ValueError: If a sampled sequence is shorter than
            ``prefix_length + target_length``.
    """
    super()._prepare_next_batch_impl(batch_size, dataset, target_length, sequences)
    assert dataset.skeleton() == self.skeleton
    nj = self.skeleton.num_joints()

    # The memory layout of the batches is: rotations or positions | translations | controls
    buffer_rot = np.zeros((batch_size, self.prefix_length + target_length,
                           nj*4 + self.translations_size + self.controls_size), dtype='float32')
    buffer_pos = np.zeros((batch_size, target_length, nj*3 + self.translations_size), dtype='float32')

    # Sampling weights proportional to sequence length.
    lengths = np.array([dataset[subject][action]['rotations'].shape[0]
                        for subject, action in sequences])
    probs = lengths / np.sum(lengths)

    pseudo_passes = (len(sequences) + batch_size - 1) // batch_size  # Round in excess
    for _ in range(pseudo_passes):
        idxs = np.random.choice(len(sequences), size=batch_size, replace=True, p=probs)
        for i, idx in enumerate(idxs):
            # Index the original list directly (np.array(sequences)[idxs] would
            # stringify the keys and build a throwaway 2-D array).
            subject, action = sequences[idx]
            seq = dataset[subject][action]

            # Pick a random chunk: prefix_length frames of context followed by
            # target_length frames to predict.
            full_seq_length = seq['rotations'].shape[0]
            max_index = full_seq_length - self.prefix_length - target_length + 1
            if max_index <= 0:
                raise ValueError(
                    'Sequence (%r, %r) has %d frames, but at least %d are required'
                    % (subject, action, full_seq_length,
                       self.prefix_length + target_length))
            start_idx = np.random.randint(0, max_index)
            mid_idx = start_idx + self.prefix_length
            end_idx = start_idx + self.prefix_length + target_length

            # Rotations cover prefix + target; positions cover only the target
            # span, whose first translations_size extra-feature columns are the
            # root translations (controls are not part of the position buffer).
            buffer_rot[i, :, :nj*4] = seq['rotations'][start_idx:end_idx].reshape(
                self.prefix_length + target_length, -1)
            buffer_rot[i, :, nj*4:] = seq['extra_features'][start_idx:end_idx]
            buffer_pos[i, :, :nj*3] = seq['positions_local'][mid_idx:end_idx].reshape(
                target_length, -1)
            buffer_pos[i, :, nj*3:] = seq['extra_features'][mid_idx:end_idx, :self.translations_size]

        # Perform data augmentation (in place, so the buffers keep their identity).
        buffer_rot[:], buffer_pos[:] = self._rotate_batch(buffer_rot, buffer_pos)
        yield buffer_rot, buffer_pos