def predict()

in short_term/pose_network_short_term.py [0:0]


    def predict(self, prefix, target_length):
        """
        Autoregressively generate a sequence that continues ``prefix``.

        The prefix is first re-encoded (regrouped into 4-component rows,
        passed through ``qeuler_np`` and ``euler_to_quaternion`` with the
        'zyx' order, then through ``qfix``), fed to the model once to obtain
        the first prediction and the recurrent state, and the model is then
        called on its own output until ``target_length`` predictions have
        been collected.

        Arguments:
         -- prefix: array supporting ``reshape``/``shape``; it is regrouped
            as ``(prefix.shape[1], -1, 4)``.
            NOTE(review): presumably a NumPy array of shape
            (1, frames, joints * 4) holding quaternions -- confirm with
            the caller.
         -- target_length: number of model calls to concatenate; must be
            strictly positive (checked with an assert).

        Returns a NumPy array obtained by concatenating the model outputs
        along dimension 1 and viewing the trailing dimension as groups of 4.
        """
        assert target_length > 0

        # Inference only: no autograd graph is recorded inside this scope.
        with torch.no_grad():
            # Regroup into 4-component rows, using the size of the prefix's
            # second axis as the new leading dimension.
            quats = prefix.reshape(prefix.shape[1], -1, 4)

            # Round trip through the 'zyx' Euler representation and back,
            # followed by qfix.
            # NOTE(review): presumably normalises the quaternion
            # representation / enforces continuity -- confirm in the helpers.
            euler = qeuler_np(quats, 'zyx')
            quats = qfix(euler_to_quaternion(euler, 'zyx'))

            # Add a leading dimension of size 1 and flatten the 4-component
            # groups before handing the data over to torch as float32.
            batch = quats.reshape(1, quats.shape[0], -1).astype('float32')
            inputs = torch.from_numpy(batch)
            if self.use_cuda:
                inputs = inputs.cuda()

            # The first call consumes the whole prefix and yields the first
            # prediction together with the recurrent state.
            step, state = self.model(inputs)
            outputs = [step]

            # Every remaining call feeds the previous prediction back in.
            for _ in range(target_length - 1):
                step, state = self.model(step, state)
                outputs.append(step)

            sequence = torch.cat(outputs, dim=1)
            n_batch, n_steps = sequence.shape[0], sequence.shape[1]
            return sequence.view(n_batch, n_steps, -1, 4).cpu().numpy()