in empose/nn/models.py [0:0]
def forward(self, batch: ABatch, window_size=None, is_new_sequence=True):
    """
    Estimate pose and shape for `batch` and refine the estimate with `self.N` feedback iterations.

    The batch is consumed window by window, as yielded by `self.window_generator`. For every window an
    initial pose/shape estimate is computed (from the RNN output if `self.rnn_init` is set, otherwise
    directly from the flattened inputs). The estimate is then updated `self.N` times by adding
    `self.step_size` times the output of `self.pose_net_iter` / `self.shape_net_iter`. These networks see
    the flattened inputs, a detached copy of the current estimate and, if `self.use_gradient` is set, the
    gradient of the marker reconstruction error w.r.t. the current estimate.

    Side effects: `self.pose_hat_history`, `self.shape_hat_history`, `self.joints_hat_history`,
    `self.markers_hat_history` and `self.markers_ori_hat_history` are overwritten. Each becomes a list of
    `self.N + 1` tensors (initial estimate plus one entry per iteration), reshaped to
    `(batch.batch_size, -1, dof)` and concatenated over all windows along dim 1. If `self.rnn_init` is set,
    `self.rnn.init_state` (and, for a new sequence, `self.rnn.final_state`) is modified as well.

    :param batch: The input batch. It is handed to `self.window_generator` and must expose `batch_size`.
    :param window_size: Forwarded unchanged to `self.window_generator`.
    :param is_new_sequence: If True and `self.rnn_init` is set, the stored RNN state is reset to None before
      the first window is processed.
    :return: A dict with the final estimates of all windows concatenated along dim 1. `root_ori_hat` holds
      the first 3 pose channels, `pose_hat` the remaining ones; `shape_hat` and `joints_hat` are the
      last shape and joint estimates.
    """
    # The reconstruction error must be differentiated w.r.t. the current estimates, so autograd has to be on
    # even if the caller runs under `torch.no_grad()`. Using the context manager (instead of a bare
    # `torch.set_grad_enabled(True)`) restores the caller's grad mode when we leave this method.
    with torch.enable_grad():
        if self.rnn_init:
            if is_new_sequence:
                self.rnn.final_state = None
            self.rnn.init_state = self.rnn.final_state

        # Each history is a nested list of size (n_windows, n_history).
        pose_hat_history = []
        shape_hat_history = []
        joints_hat_history = []
        markers_hat_history = []
        markers_ori_hat_history = []
        all_model_out = {'pose_hat': [], 'root_ori_hat': [], 'shape_hat': [], 'joints_hat': []}

        for batch_inputs in self.window_generator(batch, window_size=window_size):
            inputs_ = self.prepare_inputs(batch_inputs)
            dof = inputs_.shape[-1]
            batch_size, seq_length = inputs_.shape[0], inputs_.shape[1]

            # The offsets are given once per sequence; replicate them for every frame and flatten.
            offset_r = batch_inputs['offset_r']  # (N, M, 3, 3)
            offset_t = batch_inputs['offset_t']  # (N, M, 3)
            offset_r_flat = offset_r.unsqueeze(1).repeat(1, seq_length, 1, 1, 1).reshape(
                batch_size * seq_length, -1, 3, 3)
            offset_t_flat = offset_t.unsqueeze(1).repeat(1, seq_length, 1, 1).reshape(
                batch_size * seq_length, -1, 3)

            # Flatten everything.
            inputs_flat = inputs_.reshape((-1, dof))

            # Get initial estimate.
            if self.rnn_init:
                # Continue from wherever the RNN stopped in the previous window.
                self.rnn.init_state = self.rnn.final_state
                lstm_out = self.rnn(inputs_, batch_inputs['seq_lengths'])
                pose_hat = self.pose_net_init(lstm_out).reshape((batch_size * seq_length, -1))
                shape_hat = self.shape_net_init(lstm_out).reshape((batch_size * seq_length, -1))
            else:
                pose_hat = self.pose_net_init(inputs_flat)
                shape_hat = self.shape_net_init(inputs_flat)

            # We only want one shape per sequence, so for now average the results and pad it again.
            # NOTE(review): the mean runs over all `seq_length` frames and does not look at `seq_lengths`,
            # i.e. padded frames (if any) contribute to the average - confirm this is intended.
            def _to_single_shape(shapes):
                s = shapes.reshape(batch_size, seq_length, -1)
                s = torch.mean(s, dim=1, keepdim=True)
                return s.repeat((1, seq_length, 1)).reshape(seq_length * batch_size, -1)

            if self.shape_avg:
                shape_hat = _to_single_shape(shape_hat)

            marker_pos_hat, marker_ori_hat, joints_hat = self.get_estimated_real_markers(
                pose_hat, shape_hat, offset_r_flat, offset_t_flat, self.vertex_ids)

            # Keep track of history.
            pose_hat_history.append([pose_hat])
            shape_hat_history.append([shape_hat])
            joints_hat_history.append([joints_hat])
            markers_hat_history.append([marker_pos_hat])
            markers_ori_hat_history.append([marker_ori_hat])

            # Iterative Error Feedback.
            for _ in range(self.N):
                # The refinement networks get detached copies of the current estimate.
                input_params = [inputs_flat,
                                pose_hat_history[-1][-1].clone().detach(),
                                shape_hat_history[-1][-1].clone().detach()]

                if self.use_gradient:
                    # These are non-leaf tensors, so `.grad` is only populated if we explicitly ask for it.
                    pose_hat_history[-1][-1].retain_grad()
                    shape_hat_history[-1][-1].retain_grad()
                    joints_hat_history[-1][-1].retain_grad()
                    markers_hat_history[-1][-1].retain_grad()
                    markers_ori_hat_history[-1][-1].retain_grad()

                    reconstruction_error = torch.zeros([1]).to(dtype=inputs_.dtype, device=inputs_.device)
                    if self.config.use_marker_pos:
                        marker_pos_in = inputs_flat[:, self.pos_d_start:self.pos_d_end]
                        reconstruction_error += reconstruction_loss(
                            marker_pos_in.reshape(batch_size, seq_length, -1, 3),
                            markers_hat_history[-1][-1].reshape(batch_size, seq_length, -1, 3)[:, :, self.marker_idxs],
                            batch_inputs['seq_lengths'], batch_inputs['marker_masks'])
                    if self.config.use_marker_ori:
                        marker_ori_in = inputs_flat[:, self.ori_d_start:self.ori_d_end]
                        reconstruction_error += reconstruction_loss(
                            marker_ori_in.reshape(batch_size, seq_length, -1, 9),
                            markers_ori_hat_history[-1][-1].reshape(batch_size, seq_length, -1, 9)[:, :, self.marker_idxs],
                            batch_inputs['seq_lengths'], batch_inputs['marker_masks'])

                    # The graph is needed again by later iterations and by the caller's loss, so keep it.
                    # NOTE(review): `backward()` also accumulates into the `.grad` of every leaf reachable
                    # from the error (presumably including the network parameters) - verify that the
                    # training loop accounts for this, or consider `torch.autograd.grad`.
                    reconstruction_error.backward(retain_graph=True)

                    # NOTE(review): the scaling presumably undoes a mean reduction inside
                    # `reconstruction_loss` - TODO confirm.
                    pose_hat_grad = pose_hat_history[-1][-1].grad.clone().detach() * batch_size * seq_length
                    shape_hat_grad = shape_hat_history[-1][-1].grad.clone().detach() * batch_size * seq_length
                    input_params.append(pose_hat_grad)
                    input_params.append(shape_hat_grad)

                inputs_ = torch.cat(input_params, dim=-1)
                pose_hat_delta = self.pose_net_iter(inputs_)
                shape_hat_delta = self.shape_net_iter(inputs_)

                if self.shape_avg:
                    shape_hat_delta = _to_single_shape(shape_hat_delta)

                pose_hat = pose_hat_history[-1][-1] + pose_hat_delta * self.step_size
                shape_hat = shape_hat_history[-1][-1] + shape_hat_delta * self.step_size

                marker_pos_hat, marker_ori_hat, joints_hat = self.get_estimated_real_markers(
                    pose_hat, shape_hat, offset_r_flat, offset_t_flat, self.vertex_ids)

                pose_hat_history[-1].append(pose_hat)
                shape_hat_history[-1].append(shape_hat)
                joints_hat_history[-1].append(joints_hat)
                markers_hat_history[-1].append(marker_pos_hat)
                markers_ori_hat_history[-1].append(marker_ori_hat)

            # The last history entry of this window is its final estimate.
            pose_hat_final = pose_hat_history[-1][-1].reshape((batch_size, seq_length, -1))
            shape_hat_final = shape_hat_history[-1][-1].reshape((batch_size, seq_length, -1))
            joints_hat_final = joints_hat_history[-1][-1].reshape((batch_size, seq_length, -1))

            all_model_out['pose_hat'].append(pose_hat_final[:, :, 3:])
            all_model_out['root_ori_hat'].append(pose_hat_final[:, :, :3])
            all_model_out['shape_hat'].append(shape_hat_final)
            all_model_out['joints_hat'].append(joints_hat_final)

        # History is kept in nested list of size (n_windows, n_history), merge to list of size (n_history, ).
        def _merge_windows(nested):
            merged = []
            for h in range(self.N + 1):
                per_window = [w[h].reshape((batch.batch_size, -1, w[h].shape[-1])) for w in nested]
                merged.append(torch.cat(per_window, dim=1))
            return merged

        self.pose_hat_history = _merge_windows(pose_hat_history)
        self.shape_hat_history = _merge_windows(shape_hat_history)
        self.joints_hat_history = _merge_windows(joints_hat_history)
        self.markers_hat_history = _merge_windows(markers_hat_history)
        self.markers_ori_hat_history = _merge_windows(markers_ori_hat_history)

        model_out = {k: torch.cat(all_model_out[k], dim=1) for k in all_model_out}
        return model_out