in empose/nn/models.py [0:0]
def backward(self, batch: ABatch, model_out, writer=None, global_step=None):
    """The backward pass: compute the training losses and, in training mode, backpropagate them.

    The pose, shape and marker-reconstruction losses are accumulated over every intermediate estimate
    stored in ``self.pose_hat_history`` (and the parallel ``shape_hat_history``, ``markers_hat_history``
    and ``markers_ori_hat_history``), then averaged over the number of estimates. The forward-kinematics
    (FK) loss, when ``self.do_fk`` is set, is computed once on ``model_out['joints_hat']``.

    Args:
        batch: Input batch. Provides the raw inputs, ground-truth poses/shapes/joints, per-sequence
            lengths and marker masks.
        model_out: Output dict of the forward pass. Only ``'joints_hat'`` is read, and only if ``self.do_fk``.
        writer: Optional logging target. If not None, the loss values are passed to ``self.log_loss_vals``.
        global_step: Step index forwarded to ``self.log_loss_vals``.

    Returns:
        A tuple ``(total_loss, loss_vals)``:
            - ``total_loss`` is the weighted loss tensor.
            - ``loss_vals`` maps ``'pose'``, ``'shape'``, ``'reconstruction'``, ``'fk'`` and ``'total_loss'``
              to Python floats, each averaged over the number of estimates where applicable.
    """
    batch_size, seq_length = batch.batch_size, batch.seq_length
    n_iter = len(self.pose_hat_history)

    inputs_ = self.prepare_inputs(batch.get_inputs())
    marker_pos_in = inputs_[:, :, self.pos_d_start:self.pos_d_end]
    marker_ori_in = inputs_[:, :, self.ori_d_start:self.ori_d_end]
    # Positions are 3D vectors and orientations are flattened 3x3 matrices (9 values) per marker.
    markers_in = marker_pos_in.reshape((batch_size, seq_length, -1, 3))
    markers_ori_in = marker_ori_in.reshape((batch_size, seq_length, -1, 9))

    # Ground-truth targets do not depend on the iteration, so build them once.
    pose_gt = torch.cat([batch.poses_root, batch.poses_body], dim=-1)
    shape_gt = batch.shapes.unsqueeze(1).repeat((1, seq_length, 1))

    # Allocate the accumulators directly on the target device instead of creating them on the CPU and copying.
    reconstruction_loss_total = torch.zeros(1, device=C.DEVICE)
    shape_loss_total = torch.zeros(1, device=C.DEVICE)
    pose_loss_total = torch.zeros(1, device=C.DEVICE)
    fk_loss = torch.zeros(1, device=C.DEVICE)

    for i, pose_hat_i in enumerate(self.pose_hat_history):
        pose_hat = pose_hat_i.reshape((batch_size, seq_length, -1))
        shape_hat = self.shape_hat_history[i].reshape((batch_size, seq_length, -1))
        pose_loss_total += padded_loss(pose_gt, pose_hat, self.smpl_loss, batch.seq_lengths)
        shape_loss_total += padded_loss(shape_gt, shape_hat, self.smpl_loss, batch.seq_lengths)

        if self.config.use_marker_pos:
            markers_hat = self.markers_hat_history[i].reshape((batch_size, seq_length, -1, 3))
            # Only compare the predicted markers that correspond to observed input markers.
            reconstruction_loss_total += reconstruction_loss(markers_in, markers_hat[:, :, self.marker_idxs],
                                                             batch.seq_lengths, batch.marker_masks)
        if self.config.use_marker_ori:
            markers_ori_hat = self.markers_ori_hat_history[i].reshape((batch_size, seq_length, -1, 9))
            reconstruction_loss_total += reconstruction_loss(markers_ori_in, markers_ori_hat[:, :, self.marker_idxs],
                                                             batch.seq_lengths, batch.marker_masks)

    if self.do_fk:
        # The FK term only depends on the final output, so compute it once.
        # Previously it was summed n_iter times and then divided by n_iter, which gives the same value.
        # NOTE(review): unlike the other terms, only the final joints are supervised here. If per-iteration
        # FK supervision was intended, use self.joints_hat_history[i] inside the loop instead.
        # NOTE(review): the joint loss is masked with batch.marker_masks. Confirm that reconstruction_loss
        # expects a marker mask for joint tensors.
        joints_gt = batch.joints_gt.reshape(batch_size, seq_length, -1, 3)
        joints_hat = model_out['joints_hat'].reshape(batch_size, seq_length, -1, 3)
        fk_loss += reconstruction_loss(joints_gt, joints_hat, batch.seq_lengths, batch.marker_masks)

    iter_loss = (self.pose_weight * pose_loss_total
                 + self.shape_weight * shape_loss_total
                 + self.r_weight * reconstruction_loss_total)
    total_loss = iter_loss / n_iter + self.fk_loss_weight * fk_loss

    # .item() already copies the scalar to the host, so an explicit .cpu() is unnecessary.
    loss_vals = {'pose': pose_loss_total.item() / n_iter,
                 'shape': shape_loss_total.item() / n_iter,
                 'reconstruction': reconstruction_loss_total.item() / n_iter,
                 'fk': fk_loss.item(),
                 'total_loss': total_loss.item()}

    if writer is not None:
        self.log_loss_vals(loss_vals, writer, global_step)
    if self.training:
        total_loss.backward()
    return total_loss, loss_vals