def forward()

in empose/nn/models.py


    def forward(self, batch: ABatch, window_size=None, is_new_sequence=True):
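        """Estimate pose and shape from sensor measurements, window by window.

        Argument roles as inferred from the body below:
            batch: an ABatch carrying inputs, marker offsets, sequence lengths
                and marker masks.
            window_size: optional window length used to chunk long sequences.
            is_new_sequence: if True, reset the recurrent state before processing.
        """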
        # Gradients must be enabled even during evaluation: the iterative error
        # feedback below backpropagates the reconstruction error and feeds the
        # resulting gradients back into the network as inputs.
        torch.set_grad_enabled(True)

        if self.rnn_init:
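            # Reset the recurrent state at the start of a new sequence; otherwise
            # carry the final state of the previous call over as the initial state.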
            if is_new_sequence:
                self.rnn.final_state = None
            self.rnn.init_state = self.rnn.final_state

        pose_hat_history = []
        shape_hat_history = []
        joints_hat_history = []
        markers_hat_history = []
        markers_ori_hat_history = []
        all_model_out = {'pose_hat': [], 'root_ori_hat': [], 'shape_hat': [], 'joints_hat': []}

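        # Long inputs are processed in windows; `window_generator` is assumed to
        # yield one dict per window with inputs, offsets, sequence lengths and
        # marker masks. Per-window outputs are concatenated over time at the end.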
        for batch_inputs in self.window_generator(batch, window_size=window_size):
            inputs_ = self.prepare_inputs(batch_inputs)
            dof = inputs_.shape[-1]
            batch_size, seq_length = inputs_.shape[0], inputs_.shape[1]

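            # The calibrated marker offsets are constant over time; tile them across
            # the sequence so they align with the flattened (batch*time) estimates.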
            offset_r = batch_inputs['offset_r']  # (N, M, 3, 3)
            offset_t = batch_inputs['offset_t']  # (N, M, 3)
            offset_r_flat = offset_r.unsqueeze(1).repeat(1, seq_length, 1, 1, 1).reshape(batch_size*seq_length, -1, 3, 3)
            offset_t_flat = offset_t.unsqueeze(1).repeat(1, seq_length, 1, 1).reshape(batch_size*seq_length, -1, 3)

            # Flatten to (batch*time, dof); the per-frame networks and the IEF
            # loop below operate on flattened inputs.
            inputs_flat = inputs_.reshape((-1, dof))

            if self.rnn_init:
                # Carry the hidden state over from the previous window, then get the
                # initial estimate from the recurrent network.
                self.rnn.init_state = self.rnn.final_state
                lstm_out = self.rnn(inputs_, batch_inputs['seq_lengths'])
                pose_hat = self.pose_net_init(lstm_out).reshape((batch_size * seq_length, -1))
                shape_hat = self.shape_net_init(lstm_out).reshape((batch_size * seq_length, -1))
            else:
                # Get the initial estimate per frame from the flattened inputs.
                pose_hat = self.pose_net_init(inputs_flat)
                shape_hat = self.shape_net_init(inputs_flat)

            # We only want one shape estimate per sequence, so average the per-frame
            # estimates over time and tile the result back to full length.
            def _to_single_shape(shapes):
                s = shapes.reshape(batch_size, seq_length, -1)
                s = torch.mean(s, dim=1, keepdim=True)
                return s.repeat((1, seq_length, 1)).reshape(batch_size * seq_length, -1)

            if self.shape_avg:
                shape_hat = _to_single_shape(shape_hat)

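            # Map the current pose/shape estimate through the body model to obtain
            # estimated marker positions/orientations and joint locations.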
            marker_pos_hat, marker_ori_hat, joints_hat = self.get_estimated_real_markers(
                pose_hat, shape_hat, offset_r_flat, offset_t_flat, self.vertex_ids)

            # Start this window's history; index 0 holds the initial estimate.
            pose_hat_history.append([pose_hat])
            shape_hat_history.append([shape_hat])
            joints_hat_history.append([joints_hat])
            markers_hat_history.append([marker_pos_hat])
            markers_ori_hat_history.append([marker_ori_hat])

            # Iterative Error Feedback (IEF): refine the estimate N times by feeding
            # the current estimate (and optionally the gradient of the reconstruction
            # error w.r.t. it) back into the refinement networks.
            for i in range(self.N):
                input_params = [inputs_flat,
                                pose_hat_history[-1][-1].clone().detach(),
                                shape_hat_history[-1][-1].clone().detach()]

                if self.use_gradient:
                    # These are non-leaf tensors, so gradients must be explicitly
                    # retained for .grad to be populated by backward() below.
                    pose_hat_history[-1][-1].retain_grad()
                    shape_hat_history[-1][-1].retain_grad()
                    joints_hat_history[-1][-1].retain_grad()
                    markers_hat_history[-1][-1].retain_grad()
                    markers_ori_hat_history[-1][-1].retain_grad()

                    reconstruction_error = torch.zeros(1, dtype=inputs_.dtype, device=inputs_.device)

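                    # Compare the measured marker channels (sliced from the flat
                    # input vector) against the current estimates, restricted to
                    # the markers actually present in the input.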
                    if self.config.use_marker_pos:
                        marker_pos_in = inputs_flat[:, self.pos_d_start:self.pos_d_end]
                        reconstruction_error += reconstruction_loss(
                            marker_pos_in.reshape(batch_size, seq_length, -1, 3),
                            markers_hat_history[-1][-1].reshape(batch_size, seq_length, -1, 3)[:, :, self.marker_idxs],
                            batch_inputs['seq_lengths'], batch_inputs['marker_masks'])

                    if self.config.use_marker_ori:
                        marker_ori_in = inputs_flat[:, self.ori_d_start:self.ori_d_end]
                        reconstruction_error += reconstruction_loss(
                            marker_ori_in.reshape(batch_size, seq_length, -1, 9),
                            markers_ori_hat_history[-1][-1].reshape(batch_size, seq_length, -1, 9)[:, :, self.marker_idxs],
                            batch_inputs['seq_lengths'], batch_inputs['marker_masks'])

                    # Populate .grad on the retained estimates; keep the graph alive
                    # for subsequent backward passes.
                    reconstruction_error.backward(retain_graph=True)

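                    # The reconstruction loss is presumably a mean over batch and
                    # time; rescaling by batch_size * seq_length keeps the gradient
                    # magnitude independent of batch and window size.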
                    pose_hat_grad = pose_hat_history[-1][-1].grad.clone().detach() * batch_size * seq_length
                    shape_hat_grad = shape_hat_history[-1][-1].grad.clone().detach() * batch_size * seq_length

                    input_params.append(pose_hat_grad)
                    input_params.append(shape_hat_grad)

                # Assemble the refinement-network input from the measurements, the
                # current estimate, and (optionally) its gradients. A fresh name is
                # used so the windowed `inputs_` tensor is not shadowed.
                iter_inputs = torch.cat(input_params, dim=-1)

                pose_hat_delta = self.pose_net_iter(iter_inputs)
                shape_hat_delta = self.shape_net_iter(iter_inputs)
                if self.shape_avg:
                    shape_hat_delta = _to_single_shape(shape_hat_delta)

                # Apply the predicted correction, scaled by the step size.
                pose_hat = pose_hat_history[-1][-1] + pose_hat_delta * self.step_size
                shape_hat = shape_hat_history[-1][-1] + shape_hat_delta * self.step_size
                marker_pos_hat, marker_ori_hat, joints_hat = self.get_estimated_real_markers(
                    pose_hat, shape_hat, offset_r_flat, offset_t_flat, self.vertex_ids)

                pose_hat_history[-1].append(pose_hat)
                shape_hat_history[-1].append(shape_hat)
                joints_hat_history[-1].append(joints_hat)
                markers_hat_history[-1].append(marker_pos_hat)
                markers_ori_hat_history[-1].append(marker_ori_hat)

            pose_hat_final = pose_hat_history[-1][-1].reshape((batch_size, seq_length, -1))
            shape_hat_final = shape_hat_history[-1][-1].reshape((batch_size, seq_length, -1))
            joints_hat_final = joints_hat_history[-1][-1].reshape((batch_size, seq_length, -1))

            all_model_out['pose_hat'].append(pose_hat_final[:, :, 3:])
            all_model_out['root_ori_hat'].append(pose_hat_final[:, :, :3])
            all_model_out['shape_hat'].append(shape_hat_final)
            all_model_out['joints_hat'].append(joints_hat_final)

        # The history is a nested list of shape (n_windows, n_history); merge it
        # into a list of length n_history, concatenating the windows over time.
        def _reshape(in_, out_):
            for h in range(self.N + 1):
                tmp = []
                for k in range(len(in_)):
                    dof = in_[k][h].shape[-1]
                    tmp.append(in_[k][h].reshape((batch.batch_size, -1, dof)))
                out_.append(torch.cat(tmp, dim=1))

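        # Expose the full refinement histories on the module; presumably these are
        # consumed by a training loss computed over all IEF iterations.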
        self.pose_hat_history = []
        self.shape_hat_history = []
        self.joints_hat_history = []
        self.markers_hat_history = []
        self.markers_ori_hat_history = []
        _reshape(pose_hat_history, self.pose_hat_history)
        _reshape(shape_hat_history, self.shape_hat_history)
        _reshape(joints_hat_history, self.joints_hat_history)
        _reshape(markers_hat_history, self.markers_hat_history)
        _reshape(markers_ori_hat_history, self.markers_ori_hat_history)

        # Concatenate the per-window outputs along the time dimension.
        model_out = {k: torch.cat(all_model_out[k], dim=1) for k in all_model_out}
        return model_out
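
For reference, below is a minimal sketch of a masked reconstruction loss that is consistent with the call sites above (targets and predictions shaped (N, T, M, D), per-sequence valid lengths, per-marker masks). It is an illustration only: the function name, the mask shapes, and the exact masking/normalization scheme are assumptions, not the actual `reconstruction_loss` used by empose.

    import torch

    def reconstruction_loss_sketch(target, pred, seq_lengths, marker_masks=None):
        # Hypothetical shapes, matching the call sites above:
        #   target, pred: (N, T, M, D); seq_lengths: (N,); marker_masks: (N, T, M).
        t = target.shape[1]
        # Zero out padded frames beyond each sequence's valid length.
        time_mask = torch.arange(t, device=target.device)[None, :] < seq_lengths[:, None]  # (N, T)
        mask = time_mask[:, :, None].to(target.dtype)  # (N, T, 1)
        if marker_masks is not None:
            # Additionally mask out markers that were occluded or absent (assumed).
            mask = mask * marker_masks.to(target.dtype)  # (N, T, M)
        sq_err = ((target - pred) ** 2).sum(dim=-1)  # (N, T, M)
        mask = mask.expand_as(sq_err)
        return (sq_err * mask).sum() / mask.sum().clamp(min=1.0)

Because the initial RNN state is carried over from final_state unless is_new_sequence is True, forward can also be called repeatedly on consecutive chunks of one long sequence for streaming-style inference.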