def forward()

in models/future_prediction.py [0:0]


    def forward(self, feats, target_shape):
        """
        Args:
            feats: tensor of shape (B, T, C)
            target_shape: shape of the output (B, T', n_output)
        """
        addl_endpoints = {}
        if feats.ndim == 2:
            # add back the temporal dimension, which was likely mean pooled
            feats = feats.unsqueeze(1)
        # Decide the output len based on the target_shape
        if len(target_shape) == 3:
            output_len = target_shape[1]
        elif self.training or self.output_len_eval < 0:
            # If training mode or output_len for eval has not been set
            output_len = self.output_len
        else:  # eval mode
            output_len = self.output_len_eval
        # Keep track
        full_inp_feats = feats
        if self.assign_to_centroids:
            # Unsqueeze only to be compatible with the 1-channel (quantized)
            # inputs; the singleton dimension gets squeezed out later
            feats = self.assigner(feats).unsqueeze(-1)
        # The time dimension is already in the middle -> B, T, C
        # That's what huggingface version needs:
        # (batch_size, sequence_length, hidden_size)
        if self.in_features == 1 or self.assign_to_centroids:
            # This is a quantized input, so cast it to long, and remove the
            # last singleton dimension
            assert feats.size(-1) == 1
            feats = feats.squeeze(-1).long()
        # Keep only the first N frames; this is used when the model is given
        # more input frames than it should use for prediction. The extra
        # future frames are used to incur a loss during training, but
        # shouldn't be used otherwise, so drop those features.
        full_orig_feats = feats
        inp_feats = full_inp_feats
        if self.drop_last_n != 0:
            logging.warning('This should be used very carefully, ideally only '
                            'for debugging. The padding can cause some frames '
                            'from the actual clip to leak into the past clip, '
                            'even after dropping the last n frames. So even '
                            'after dropping, the model might still end up '
                            'seeing frames beyond tau_a.')
            feats = feats[:, :-self.drop_last_n]
            inp_feats = inp_feats[:, :-self.drop_last_n]
        # Keep track
        orig_feats_len = feats.size(1)
        # Reduce the dimensionality, since I'm not using the GPT token
        # embedding matrix -- I don't have a "token" representation here
        feats = self.encoder(feats)
        orig_feats_encoded = feats
        past = None
        all_outputs = []
        all_outputs_decoded = []
        for output_id in range(output_len):
            pred_so_far = sum([el.size(1) for el in all_outputs])
            position_ids = torch.arange(pred_so_far,
                                        pred_so_far + feats.size(1),
                                        dtype=torch.long,
                                        device=feats.device)
            # The returned past will encode the previous past AND the new
            # input (you can check the output, its length keeps increasing).
            # Got this from
            # https://huggingface.co/transformers/quickstart.html#using-the-past
            # (see the standalone rollout sketch after this function)
            outputs = self.gpt_model(inputs_embeds=feats,
                                     past_key_values=past,
                                     position_ids=position_ids)
            last_hidden_state = outputs.last_hidden_state
            past = outputs.past_key_values
            all_outputs.append(last_hidden_state)
            # For visualization later, if output_attentions was passed into gpt
            if outputs.attentions is not None:
                # dimensions will be (batch_size, nlayers, nheads, seqlen, seqlen)
                addl_endpoints[f'gpt2_att_{output_id}'] = torch.stack(
                    outputs.attentions).transpose(0, 1)
            # Map back to the original feature dimension
            all_outputs_decoded.append(self.decoder(last_hidden_state))
            # hidden_states[-1] (i.e. last_hidden_state) is the embedding from
            # the final layer. Not using logits (earlier versions used the
            # LMHead model, which returned logits) since those are already
            # decoded to the vocab size and I want control over the weights of
            # that final matrix. Also, the input for the next step must be an
            # encoding, so access the encodings directly.
            if self.quantize_before_rollout:
                assert isinstance(self.encoder, nn.Embedding)
                feats = self.encoder(
                    all_outputs_decoded[-1][:, -1:, :].argmax(dim=-1))
            else:
                feats = last_hidden_state[:, -1:, :]
        all_outputs = torch.cat(all_outputs, dim=1)
        all_outputs_decoded = torch.cat(all_outputs_decoded, dim=1)
        # Compute a loss on future prediction (teacher forced): the decoded
        # prediction at step t is compared to the ground-truth feature at
        # step t+1, hence the 1-frame shift below (see the alignment sketch
        # after this function)
        losses = {}
        if self.future_pred_loss is not None:
            num_elts_for_loss = min(full_orig_feats.size(1),
                                    all_outputs_decoded.size(1))
            losses = {
                'feat':
                self.future_pred_loss(
                    all_outputs_decoded[:, :num_elts_for_loss - 1],
                    full_orig_feats[:, 1:num_elts_for_loss])
            }
        # Set all_outputs as the final output features, and prev as the
        # structure used to recover the original features of the past
        if self.in_features == 1:
            prev = orig_feats_encoded
            # all_outputs already contains the hidden states, which is the
            # best we can get here, so that doesn't change
        elif self.assign_to_centroids:
            prev = inp_feats  # For this, I have the orig feats, so use that
            # For prediction, use the predicted cluster centers, but take the
            # features from the original kmeans centroids rather than the
            # learnt embeddings -- it didn't work with those
            all_outputs = self.assigner(all_outputs_decoded.argmax(dim=-1))
        else:
            prev = inp_feats
            all_outputs = all_outputs_decoded
        # Return the actual predictions
        if self.return_past_too:
            # Pad in the GT past (no point using the predicted past when
            # we have the actual past)
            final = torch.cat((prev, all_outputs[:, orig_feats_len - 1:, :]),
                              dim=1)
        elif output_len > 0:
            final = all_outputs[:, -output_len:]
        else:
            final = all_outputs
        if self.avg_last_n > 0:
            final = torch.mean(final[:, -self.avg_last_n:, :], dim=1)
        # Compute the past feature.
        assert prev.size(1) == orig_feats_len, (
            'If not, need to figure out how to handle it')
        # Keep the original feature for the first position, and for the rest
        # return the features shifted by 1 -- i.e. as predicted by GPT
        updated_past_feat = torch.cat(
            [prev[:, :1, :], all_outputs[:, :(orig_feats_len - 1)]], dim=1)
        return updated_past_feat, final, losses, addl_endpoints
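
Rollout sketch referenced above: a minimal, standalone illustration of the
autoregressive loop with past_key_values caching, using a plain HuggingFace
GPT2Model. The dimensions, number of steps, and variable names here are
illustrative assumptions, not taken from this module.

    import torch
    from transformers import GPT2Config, GPT2Model

    gpt = GPT2Model(GPT2Config(n_embd=256, n_layer=2, n_head=4)).eval()
    feats = torch.randn(2, 8, 256)   # (B, T, hidden): encoded "past" features
    past = None
    preds = []
    with torch.no_grad():
        for step in range(4):        # roll out 4 future steps
            offset = sum(p.size(1) for p in preds)
            position_ids = torch.arange(offset, offset + feats.size(1),
                                        dtype=torch.long)
            out = gpt(inputs_embeds=feats, past_key_values=past,
                      position_ids=position_ids)
            past = out.past_key_values                # cache grows every step
            preds.append(out.last_hidden_state)
            feats = out.last_hidden_state[:, -1:, :]  # feed back newest step
    future = torch.cat(preds, dim=1)[:, -4:]          # keep the 4 predicted steps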
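
Alignment sketch referenced in the loss computation: a toy check of the
shift-by-one, teacher-forced target construction. Tensors and names here are
made up for illustration.

    import torch
    import torch.nn.functional as F

    gt = torch.arange(6.).view(1, 6, 1)   # ground-truth features 0..5
    pred = gt + 0.1                        # pretend decoded predictions
    n = min(gt.size(1), pred.size(1))
    # the prediction at step t is scored against the ground truth at step t+1
    loss = F.mse_loss(pred[:, :n - 1], gt[:, 1:n])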