def _prepare_next_batch_impl()

in long_term/pose_network_long_term.py [0:0]


    def _prepare_next_batch_impl(self, batch_size, dataset, target_length, sequences):
        """Yield randomly sampled, augmented training batches.

        Each yielded batch is built from ``batch_size`` (subject, action) pairs
        drawn with replacement from ``sequences``. The draw weight of a pair is
        the length (first axis) of its ``'rotations'`` array. For every draw a
        random window of ``self.prefix_length + target_length`` consecutive
        entries is cut out of the sequence.

        Args:
            batch_size: Number of windows per yielded batch.
            dataset: Indexable as ``dataset[subject][action][key]`` for the keys
                ``'rotations'``, ``'positions_local'`` and ``'extra_features'``;
                ``dataset.skeleton()`` must equal ``self.skeleton``.
            target_length: Number of entries following the prefix that form
                the target part of each window.
            sequences: Indexable collection of ``(subject, action)`` pairs.

        Yields:
            ``(buffer_rot, buffer_pos)`` float32 arrays of shapes
            ``(batch_size, prefix_length + target_length,
            nj*4 + translations_size + controls_size)`` and
            ``(batch_size, target_length, nj*3 + translations_size)``.
            The SAME two arrays are refilled in place on every iteration, so a
            consumer that needs a batch after the next ``next()`` call must
            copy it.

        Raises:
            ValueError: If a sampled sequence is shorter than
                ``self.prefix_length + target_length``.
        """
        super()._prepare_next_batch_impl(batch_size, dataset, target_length, sequences)

        assert dataset.skeleton() == self.skeleton
        nj = self.skeleton.num_joints()
        window_length = self.prefix_length + target_length

        # The memory layout of the batches is: rotations or positions | translations | controls
        buffer_rot = np.zeros((batch_size, window_length,
                               nj*4 + self.translations_size + self.controls_size), dtype='float32')
        buffer_pos = np.zeros((batch_size, target_length, nj*3 + self.translations_size), dtype='float32')

        # Longer sequences are drawn proportionally more often.
        lengths = np.array([dataset[subject][action]['rotations'].shape[0]
                            for subject, action in sequences])
        probs = lengths / np.sum(lengths)

        pseudo_passes = (len(sequences) + batch_size - 1) // batch_size  # Round in excess
        for _ in range(pseudo_passes):
            idxs = np.random.choice(len(sequences), size=batch_size, replace=True, p=probs)
            # Index the original collection rather than np.array(sequences):
            # numpy would coerce mixed-type (subject, action) keys to a common
            # string dtype, breaking dict lookups for non-string keys.
            for i, idx in enumerate(idxs):
                subject, action = sequences[idx]
                seq_data = dataset[subject][action]

                # Pick a random chunk; max_index is the number of valid start positions.
                full_seq_length = seq_data['rotations'].shape[0]
                max_index = full_seq_length - window_length + 1
                if max_index <= 0:
                    raise ValueError(
                        f"Sequence ({subject!r}, {action!r}) has length {full_seq_length}, "
                        f"shorter than prefix_length + target_length = {window_length}")
                start_idx = np.random.randint(0, max_index)
                mid_idx = start_idx + self.prefix_length
                end_idx = start_idx + window_length

                buffer_rot[i, :, :nj*4] = seq_data['rotations'][start_idx:end_idx].reshape(
                    window_length, -1)
                buffer_rot[i, :, nj*4:] = seq_data['extra_features'][start_idx:end_idx]

                buffer_pos[i, :, :nj*3] = seq_data['positions_local'][mid_idx:end_idx].reshape(
                    target_length, -1)
                buffer_pos[i, :, nj*3:] = seq_data['extra_features'][mid_idx:end_idx, :self.translations_size]

            # Perform data augmentation; results are written back into the reused buffers.
            buffer_rot[:], buffer_pos[:] = self._rotate_batch(buffer_rot, buffer_pos)

            yield buffer_rot, buffer_pos