# func/train_eval_ops.py
def __call__(
        self,
        data: Union[Dict[str, torch.Tensor],  # If dict
                    Tuple[torch.Tensor, torch.Tensor]],  # vid, target
        train_mode: bool = True):
    """Run the model on a batch and compute classification/regression losses.

    Args:
        data: Either a dict containing 'video' and 'target' (and, during
            training, extra ``{FUTURE_PREFIX}_{i}_video`` entries), or a
            ``(video, target)`` tuple — normalized by ``_basic_preproc``.
        train_mode: When True, the sampled future clips are appended to the
            batch and the future-regression losses are incurred as well.

    Returns:
        Tuple of ``(data, outputs, losses, accuracies)`` where ``outputs``
        covers only the original (non-future) batch elements.
    """
    data = self._basic_preproc(data, train_mode)
    video = data['video'].to(self.device, non_blocking=True)
    target = {
        key: val.to(self.device, non_blocking=True)
        for key, val in data['target'].items()
    }
    batch_size = video.size(0)
    if train_mode:
        # At test time, I don't sample the extra future video, since
        # that is only used during training
        all_videos = [video]
        nfutures = len(
            [key for key in data.keys() if key.startswith(FUTURE_PREFIX)])
        for i in range(nfutures):
            future_vid = data[f'{FUTURE_PREFIX}_{i}_video'].to(
                self.device, non_blocking=True)
            all_videos.append(future_vid)
        video = torch.cat(all_videos, dim=0)  # Add to batch dim
    outputs_full, aux_losses = self.model(video)
    # Just the actual video for outputs
    outputs = {key: val[:batch_size] for key, val in outputs_full.items()}
    # NOTE: The cls loss is deliberately NOT gated/weighted by
    # self.cls_loss_wt here: skipping it makes some layers not receive
    # gradients, which raises errors. The gradient should be 0 anyway.
    losses, accuracies = self.cls_loss_acc_fn(outputs, target)
    losses.update(aux_losses)
    if train_mode:
        # Incur the regression losses, for each of the futures
        reg_losses = []
        if self.incur_loss_style == 'separately':
            for i in range(nfutures):
                future_feats = outputs_full[self.future_target][
                    (i + 1) * batch_size:(i + 2) * batch_size]
                if self.cumulative_future:
                    future_feats = torch.cumsum(future_feats, 0)
                    # Divide by the position to get mean of features
                    # until then. torch.arange(1, n + 1) replaces the
                    # deprecated torch.range(1, n) (same inclusive values).
                    future_feats = future_feats / (torch.arange(
                        1,
                        future_feats.size(0) + 1,
                        device=future_feats.device,
                        dtype=future_feats.dtype).unsqueeze(1))
                loss = self.reg_criterion(outputs['future_projected'],
                                          future_feats)
                reg_losses.append(loss)
            final_reg_loss = hydra.utils.call(self.combine_future_losses,
                                              torch.stack(reg_losses))
        elif self.incur_loss_style == 'together':
            future_feats = outputs_full[self.future_target][batch_size:]
            future_feats = future_feats.reshape(
                (-1, batch_size, future_feats.size(-1))).transpose(0, 1)
            final_reg_loss = self.reg_criterion(
                outputs['future_projected'], future_feats)
        else:
            raise NotImplementedError(self.incur_loss_style)
        losses['reg'] = final_reg_loss
    return data, outputs, losses, accuracies