in empose/nn/models.py [0:0]
def backward(self, batch: ABatch, model_out, writer=None, global_step=None):
    """Compute the training losses and, in training mode, backpropagate them.

    The total loss is ``pose + root_pose + shape + self.fk_loss_weight * fk``.
    The shape term is only active when ``self.estimate_shape`` is set, and the
    FK term only when ``self.do_fk`` is set. Otherwise each contributes a
    one-element zero tensor.

    Args:
        batch: Ground-truth batch. Reads ``batch_size``, ``seq_length``,
            ``seq_lengths``, ``poses_body``, ``poses_root`` and ``marker_masks``.
            It also reads ``shapes`` when ``self.estimate_shape`` is set and
            ``joints_gt`` when ``self.do_fk`` is set.
        model_out: Dict of predictions with keys ``'pose_hat'``,
            ``'root_ori_hat'`` and ``'shape_hat'``, plus ``'joints_hat'`` when
            ``self.do_fk`` is set.
        writer: Optional writer handed to ``self.log_loss_vals``. Nothing is
            logged when it is None.
        global_step: Step index forwarded to ``self.log_loss_vals``.

    Returns:
        A tuple ``(total_loss, loss_vals)``. ``total_loss`` is the summed loss
        tensor. ``loss_vals`` maps ``'pose'``, ``'root_pose'``, ``'shape'``,
        ``'fk'`` and ``'total_loss'`` to Python floats. ``total_loss.backward()``
        is called only when ``self.training`` is True.
    """
    pose_hat = model_out['pose_hat']
    root_ori_hat = model_out['root_ori_hat']
    shape_hat = model_out['shape_hat']
    n, f = batch.batch_size, batch.seq_length

    # Group the flat pose vectors into triplets: (n, f, joints, 3).
    p_body = batch.poses_body.reshape(n, f, -1, 3)
    p_root = batch.poses_root.reshape(n, f, -1, 3)

    # Joint-wise squared L2 norm for rotations (computed by normal_mse over the
    # valid frames given by seq_lengths).
    pose_loss = normal_mse(p_body, pose_hat.reshape(n, f, -1, 3),
                           batch.seq_lengths, batch.marker_masks)
    root_pose_loss = normal_mse(p_root, root_ori_hat.reshape(n, f, -1, 3),
                                batch.seq_lengths, batch.marker_masks)

    # Placeholder for disabled loss terms. It is created on the predictions'
    # device/dtype so the sum below never mixes devices.
    zero = pose_hat.new_zeros(1)

    if self.estimate_shape:
        # Ground-truth shape has no time axis here; tile it along shape_hat's
        # second dimension so it can be compared frame by frame.
        shape_gt = batch.shapes.unsqueeze(1).repeat((1, shape_hat.shape[1], 1))
        shape_loss = padded_loss(shape_gt, shape_hat, self.shape_loss, batch.seq_lengths)
    else:
        shape_loss = zero

    if self.do_fk:
        joints_gt = batch.joints_gt.reshape(n, f, -1, 3)
        joints_hat = model_out['joints_hat'].reshape(n, f, -1, 3)
        fk_loss = reconstruction_loss(joints_gt, joints_hat, batch.seq_lengths, batch.marker_masks)
    else:
        fk_loss = zero

    total_loss = pose_loss + root_pose_loss + shape_loss + self.fk_loss_weight * fk_loss

    # .item() already copies to host and returns a Python scalar.
    loss_vals = {'pose': pose_loss.item(),
                 'root_pose': root_pose_loss.item(),
                 'shape': shape_loss.item(),
                 'fk': fk_loss.item(),
                 'total_loss': total_loss.item()}

    if writer is not None:
        self.log_loss_vals(loss_vals, writer, global_step)
    if self.training:
        total_loss.backward()
    return total_loss, loss_vals