# backward() — excerpt from empose/nn/models.py


    def backward(self, batch: ABatch, model_out, writer=None, global_step=None):
        """Compute the training losses and run the backward pass.

        Args:
            batch: Input batch providing ground-truth body/root poses, shapes,
                joints, sequence lengths, and marker masks.
            model_out: Dict with keys 'pose_hat', 'root_ori_hat', 'shape_hat'
                (and 'joints_hat' when `self.do_fk` is enabled).
            writer: Optional summary writer; loss values are logged when given.
            global_step: Step index used when logging to `writer`.

        Returns:
            Tuple `(total_loss, loss_vals)`: the total loss tensor and a dict
            mapping each loss name to its Python float value.
        """
        pose_hat = model_out['pose_hat']
        root_ori_hat = model_out['root_ori_hat']
        shape_hat = model_out['shape_hat']

        n, f = batch.batch_size, batch.seq_length

        p_body = batch.poses_body.reshape(n, f, -1, 3)
        p_root = batch.poses_root.reshape(n, f, -1, 3)

        # Joint-wise squared L2 norm for rotations.
        pose_loss = normal_mse(p_body, pose_hat.reshape(n, f, -1, 3),
                               batch.seq_lengths, batch.marker_masks)
        root_pose_loss = normal_mse(p_root, root_ori_hat.reshape(n, f, -1, 3),
                                    batch.seq_lengths, batch.marker_masks)

        if self.estimate_shape:
            # Broadcast the per-sequence shape vector to every frame so it can
            # be compared against the per-frame shape estimates.
            shape_loss = padded_loss(batch.shapes.unsqueeze(1).repeat((1, shape_hat.shape[1], 1)),
                                     shape_hat, self.shape_loss, batch.seq_lengths)
        else:
            # Allocate directly on the target device (no CPU alloc + copy).
            shape_loss = torch.zeros(1, device=C.DEVICE)

        if self.do_fk:
            # Forward-kinematics loss on reconstructed 3D joint positions.
            joints_gt = batch.joints_gt.reshape(n, f, -1, 3)
            joints_hat = model_out['joints_hat'].reshape(n, f, -1, 3)
            fk_loss = reconstruction_loss(joints_gt, joints_hat, batch.seq_lengths, batch.marker_masks)
        else:
            fk_loss = torch.zeros(1, device=C.DEVICE)

        total_loss = pose_loss + root_pose_loss + shape_loss + self.fk_loss_weight * fk_loss

        # .item() already copies the scalar to the host; no .cpu() needed.
        loss_vals = {'pose': pose_loss.item(),
                     'root_pose': root_pose_loss.item(),
                     'shape': shape_loss.item(),
                     'fk': fk_loss.item(),
                     'total_loss': total_loss.item()}

        if writer is not None:
            self.log_loss_vals(loss_vals, writer, global_step)

        if self.training:
            total_loss.backward()

        return total_loss, loss_vals