in short_term/pose_network_short_term.py [0:0]
def predict(self, prefix, target_length):
"""
Predict a sequence using the given prefix.
"""
assert target_length > 0
with torch.no_grad():
prefix = prefix.reshape(prefix.shape[1], -1, 4)
prefix = qeuler_np(prefix, 'zyx')
prefix = qfix(euler_to_quaternion(prefix, 'zyx'))
inputs = torch.from_numpy(prefix.reshape(1, prefix.shape[0], -1).astype('float32'))
if self.use_cuda:
inputs = inputs.cuda()
predicted, hidden = self.model(inputs)
frames = [predicted]
for i in range(1, target_length):
predicted, hidden = self.model(predicted, hidden)
frames.append(predicted)
result = torch.cat(frames, dim=1)
return result.view(result.shape[0], result.shape[1], -1, 4).cpu().numpy()