# models/future_prediction.py
def forward(self, feats, target_shape):
"""
Args:
feats: tensor of shape (B, T, C)
target_shape: shape of the output (B, T', n_output)
"""
addl_endpoints = {}
if feats.ndim == 2:
# Add back the temporal dimension, which was likely mean-pooled away
feats = feats.unsqueeze(1)
# Decide the output len based on the target_shape
if len(target_shape) == 3:
output_len = target_shape[1]
elif self.training or self.output_len_eval < 0:
# If training mode or output_len for eval has not been set
output_len = self.output_len
else: # eval mode
output_len = self.output_len_eval
# Keep a copy of the full input features, before any centroid assignment
full_inp_feats = feats
if self.assign_to_centroids:
# Unsqueeze only to be compatible with the 1-channel (quantized) inputs;
# the singleton dimension gets squeezed out below
feats = self.assigner(feats).unsqueeze(-1)
# The time dimension is already in the middle -> B, T, C
# That's what huggingface version needs:
# (batch_size, sequence_length, hidden_size)
if self.in_features == 1 or self.assign_to_centroids:
# This is a quantized input, so cast it to long, and remove the
# last singleton dimension
assert feats.size(-1) == 1
feats = feats.squeeze(-1).long()
# Keep only the first N frames. This is used when the model is given more
# input frames than it should use for prediction: the extra future frames
# are only there to compute the loss during training and shouldn't be used
# otherwise, so drop those features from the input here.
full_orig_feats = feats
inp_feats = full_inp_feats
if self.drop_last_n != 0:
logging.warning('This should be used very carefully, ideally only '
'for debugging. The padding can cause some frames '
'from the actual clip to leak into the past clip, '
'even after dropping the last n. So the model might '
'still end up seeing frames beyond tau_a.')
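# Trim the last n frames from both the (possibly quantized) features and
# the kept full-feature copy.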
feats = feats[:, :-self.drop_last_n]
inp_feats = inp_feats[:, :-self.drop_last_n]
# Keep track of the (possibly trimmed) input length
orig_feats_len = feats.size(1)
# Reduce the dimensionality; the GPT token-embedding matrix is not used
# here since there is no "token" representation for these features
feats = self.encoder(feats)
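# Keep the encoded features; they serve as the past representation when
# the inputs were quantized (in_features == 1).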
orig_feats_encoded = feats
past = None
all_outputs = []
all_outputs_decoded = []
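# Autoregressively roll out output_len future steps. The first iteration
# consumes the full encoded context; later iterations feed only the newest
# predicted step, with past_key_values caching everything before it.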
for output_id in range(output_len):
pred_so_far = sum([el.size(1) for el in all_outputs])
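# With past_key_values caching, the position ids must continue from the
# number of steps already processed, hence the explicit offset.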
position_ids = torch.arange(pred_so_far,
pred_so_far + feats.size(1),
dtype=torch.long,
device=feats.device)
# The returned past will encode the previous past AND the new input
# (its length keeps increasing across iterations)
# Got this from
# https://huggingface.co/transformers/quickstart.html#using-the-past
outputs = self.gpt_model(inputs_embeds=feats,
past_key_values=past,
position_ids=position_ids)
last_hidden_state = outputs.last_hidden_state
past = outputs.past_key_values
all_outputs.append(last_hidden_state)
# For visualization later, if output_attentions was passed into gpt
if outputs.attentions is not None:
# dimensions will be (batch_size, nlayers, nheads, seqlen, seqlen)
addl_endpoints[f'gpt2_att_{output_id}'] = torch.stack(
outputs.attentions).transpose(0, 1)
# Map back to the original feature dimension
all_outputs_decoded.append(self.decoder(last_hidden_state))
# hidden_states[-1] (i.e. last_hidden_state) is the embedding from the
# final layer. Logits are not used (an earlier version used the LMHead
# model, which returned logits) since those are already decoded to the
# vocab size, and we want control over the weights of that final matrix.
# Also, the input for the next step has to be an encoding, so the hidden
# states are accessed directly.
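# Prepare the input for the next step: either re-quantize the decoded
# prediction and re-embed it, or feed the last hidden state back directly.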
if self.quantize_before_rollout:
assert isinstance(self.encoder, nn.Embedding)
feats = self.encoder(
all_outputs_decoded[-1][:, -1:, :].argmax(dim=-1))
else:
feats = last_hidden_state[:, -1:, :]
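# Concatenate the per-step hidden states and decoded features along time.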
all_outputs = torch.cat(all_outputs, dim=1)
all_outputs_decoded = torch.cat(all_outputs_decoded, dim=1)
# Compute a loss on future prediction (teacher forced)
losses = {}
if self.future_pred_loss is not None:
num_elts_for_loss = min(full_orig_feats.size(1),
all_outputs_decoded.size(1))
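# Teacher forcing: the prediction at step t is compared against the
# ground-truth feature at step t+1, hence the one-step shift below.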
losses = {
'feat':
self.future_pred_loss(
all_outputs_decoded[:, :num_elts_for_loss - 1],
full_orig_feats[:, 1:num_elts_for_loss])
}
# Set all_outputs to the final output features, and prev to the tensor
# used to recover the original features of the past.
if self.in_features == 1:
prev = orig_feats_encoded
# all_outputs already contains the hidden states, which is the best we
# will get anyway, so it is left unchanged
elif self.assign_to_centroids:
prev = inp_feats  # The original features are available here, so use them
# For the prediction, use the predicted cluster centers, taking features
# from the original kmeans centroids rather than the learnt embeddings;
# the latter did not work as well.
all_outputs = self.assigner(all_outputs_decoded.argmax(dim=-1))
else:
prev = inp_feats
all_outputs = all_outputs_decoded
# Return the actual predictions
if self.return_past_too:
# Pad in the GT past (no point using the predicted past when
# we have the actual past)
final = torch.cat((prev, all_outputs[:, orig_feats_len - 1:, :]),
dim=1)
elif output_len > 0:
final = all_outputs[:, -output_len:]
else:
final = all_outputs
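# Optionally collapse the last avg_last_n predicted steps into a single
# feature by mean pooling over time.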
if self.avg_last_n > 0:
final = torch.mean(final[:, -self.avg_last_n:, :], dim=1)
# Compute the updated past feature.
assert prev.size(1) == orig_feats_len, (
'If not, need to figure out how to handle this case')
# Keep the original feature for the first time step, and for the rest
# return the features shifted by 1, i.e. as predicted by GPT.
updated_past_feat = torch.cat(
[prev[:, :1, :], all_outputs[:, :(orig_feats_len - 1)]], dim=1)
return updated_past_feat, final, losses, addl_endpoints
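
# A minimal usage sketch (hypothetical class name and shapes, assuming the
# module was built with in_features == C and a fixed output_len):
#   model = FuturePredictionGPT(...)       # hypothetical constructor
#   feats = torch.randn(4, 10, 768)        # (B, T, C)
#   past_feat, future_feat, losses, ends = model(
#       feats, target_shape=(4, 5, 768))   # roll out 5 future steps
#   # past_feat covers the (possibly trimmed) input length;
#   # future_feat covers the 5 predicted steps.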