def backward()

in empose/nn/models.py
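
Computes the training losses for one batch and, in training mode, backpropagates; each loss term is averaged over the model's intermediate predictions.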


    def backward(self, batch: ABatch, model_out, writer=None, global_step=None):
        """The backward pass."""
        batch_size, seq_length = batch.batch_size, batch.seq_length
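        # Split the flattened inputs into marker positions (3 values per marker)
        # and marker orientations (3x3 rotation matrices flattened to 9 values).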
        inputs_ = self.prepare_inputs(batch.get_inputs())
        marker_pos_in = inputs_[:, :, self.pos_d_start:self.pos_d_end]
        marker_ori_in = inputs_[:, :, self.ori_d_start:self.ori_d_end]
        markers_in = marker_pos_in.reshape((batch_size, seq_length, -1, 3))
        markers_ori_in = marker_ori_in.reshape((batch_size, seq_length, -1, 9))

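        # Zero-initialized accumulators for the individual loss terms.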
        reconstruction_loss_total = torch.zeros(1, device=C.DEVICE)
        shape_loss_total = torch.zeros(1, device=C.DEVICE)
        pose_loss_total = torch.zeros(1, device=C.DEVICE)
        fk_loss_total = torch.zeros(1, device=C.DEVICE)

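        # Evaluate every loss on each intermediate prediction kept in the
        # history buffers, not only on the final output.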
        for i in range(len(self.pose_hat_history)):
            pose_hat = self.pose_hat_history[i].reshape((batch_size, seq_length, -1))
            shape_hat = self.shape_hat_history[i].reshape((batch_size, seq_length, -1))

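            # Supervise the predicted pose (root + body) and shape parameters
            # directly, ignoring padded timesteps via `batch.seq_lengths`.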
            pose_loss_total += padded_loss(torch.cat([batch.poses_root, batch.poses_body], dim=-1),
                                           pose_hat, self.smpl_loss, batch.seq_lengths)
            shape_loss_total += padded_loss(batch.shapes.unsqueeze(1).repeat((1, seq_length, 1)),
                                            shape_hat, self.smpl_loss, batch.seq_lengths)

            if self.do_fk:
                # Forward-kinematics loss: compare the i-th intermediate joint
                # prediction against the ground-truth joint positions.
                joints_gt = batch.joints_gt.reshape(batch_size, seq_length, -1, 3)
                joints_hat = self.joints_hat_history[i].reshape(batch_size, seq_length, -1, 3)
                fk_loss_total += reconstruction_loss(joints_gt, joints_hat,
                                                     batch.seq_lengths, batch.marker_masks)

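            # Marker position reconstruction: the predicted markers are indexed
            # with `self.marker_idxs` to align them with the observed inputs.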
            if self.config.use_marker_pos:
                markers_hat = self.markers_hat_history[i].reshape((batch_size, seq_length, -1, 3))
                reconstruction_loss_total += reconstruction_loss(markers_in, markers_hat[:, :, self.marker_idxs],
                                                                 batch.seq_lengths, batch.marker_masks)

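            # The same reconstruction loss, applied to the marker orientations.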
            if self.config.use_marker_ori:
                markers_ori_hat = self.markers_ori_hat_history[i].reshape((batch_size, seq_length, -1, 9))
                reconstruction_loss_total += reconstruction_loss(markers_ori_in, markers_ori_hat[:, :, self.marker_idxs],
                                                                 batch.seq_lengths, batch.marker_masks)

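        # Weighted sum of the individual loss terms, averaged over the number
        # of intermediate predictions.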
        n_steps = len(self.pose_hat_history)
        total_loss = self.pose_weight * pose_loss_total + self.fk_loss_weight * fk_loss_total
        total_loss += self.shape_weight * shape_loss_total + self.r_weight * reconstruction_loss_total
        total_loss = total_loss / n_steps

        loss_vals = {'pose': pose_loss_total.cpu().item() / n_steps,
                     'shape': shape_loss_total.cpu().item() / n_steps,
                     'reconstruction': reconstruction_loss_total.cpu().item() / n_steps,
                     'fk': fk_loss_total.cpu().item() / n_steps,
                     'total_loss': total_loss.cpu().item()}

        if writer is not None:
            self.log_loss_vals(loss_vals, writer, global_step)

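        # Backpropagate only in training mode; during evaluation this method
        # only reports the loss values.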
        if self.training:
            total_loss.backward()

        return total_loss, loss_vals
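
The helpers `padded_loss` and `reconstruction_loss` are not shown in this excerpt. The following is a minimal sketch of how such masked losses could look, inferred only from the call sites above; the tensor shapes, the meaning of `seq_lengths` (valid frames per sequence) and `marker_masks` (per-frame marker validity), and the use of an elementwise loss are all assumptions, not the repository's actual implementation.

    import torch

    def padded_loss(target, prediction, loss_fn, seq_lengths):
        """Average `loss_fn` over the valid (non-padded) timesteps only.

        target, prediction: (batch, seq, d). `loss_fn` is assumed to be an
        elementwise loss module constructed with reduction='none'.
        """
        seq_length = target.shape[1]
        time_idx = torch.arange(seq_length, device=target.device)[None, :]
        valid = (time_idx < seq_lengths[:, None]).float()       # (batch, seq)
        per_step = loss_fn(prediction, target).mean(dim=-1)     # (batch, seq)
        return (per_step * valid).sum() / valid.sum().clamp(min=1.0)

    def reconstruction_loss(target, prediction, seq_lengths, marker_masks):
        """Mean per-marker Euclidean error over valid frames and visible markers.

        target, prediction: (batch, seq, n_markers, d);
        marker_masks: (batch, seq, n_markers), nonzero where a marker is observed.
        """
        seq_length = target.shape[1]
        time_idx = torch.arange(seq_length, device=target.device)[None, :]
        valid = (time_idx < seq_lengths[:, None]).float()       # (batch, seq)
        mask = marker_masks.float() * valid[..., None]          # (batch, seq, n_markers)
        err = torch.linalg.norm(prediction - target, dim=-1)    # (batch, seq, n_markers)
        return (err * mask).sum() / mask.sum().clamp(min=1.0)

Normalizing by the mask sum rather than a fixed element count keeps the loss magnitude comparable across batches with different amounts of padding and occlusion.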