in agents/nets.py
def forward(self,
vid,
vid_is_solved,
n_hist_frames=3,
n_fwd_times=1,
n_fwd_times_incur_loss=999999,
run_decode=False,
compute_losses=False,
need_intermediate=False,
autoenc_loss_ratio=0.0,
nslices=-1):
"""
Args:
vid: (BxTxNobjxHxW) The input video
vid_is_solved: (Bx1) Whether the video is solved in the end or not.
Could be None at test time.
n_hist_frames: (int) Number of frames to use as history for
prediction
n_fwd_times: (int) How many times to run the forward dynamics model
n_fwd_times_incur_loss (int): Up to how many of these forward steps
to incur loss on.
run_decode: (bool) Decode the features into pixel output
compute_losses: Should be set at train time. Will compute whatever
losses it can given the data (e.g., if vid_is_solved is not
passed to the function, it will not compute the CE loss).
need_intermediate (bool): Set true if you want to run the dynamics
model and need all the intermediate results. Else, will return
a list with only 1 element, the final output.
autoenc_loss_ratio (float between 0 and 1): Set to 1 to run an
auto-encoder style loss on all frames when run_decode is set.
nslices (int): See the _slice_for_dyn function.
Returns:
all_preds (dict): Predictions with keys 'feats', 'is_solved' and
'pixels'.
all_losses: Dict of the computed losses when compute_losses is
set; empty otherwise.
"""
vid_preproc = self.preproc.preprocess_vid(vid)
obj_feat = self.enc(vid_preproc)
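# Slice out the history frames (both features and pixels) that seed the
# dynamics model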
clip_hist = self._slice_for_dyn(obj_feat,
n_hist_frames,
nslices=nslices)
vid_hist = self._slice_for_dyn(vid_preproc,
n_hist_frames,
nslices=nslices)
assert clip_hist.shape[1] == n_hist_frames
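# Run the interactor module over the history features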
clip_hist = self.interactor(clip_hist)
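# Roll the dynamics model forward n_fwd_times steps; returns per-step feature
# predictions, optional pixel predictions, and any extra losses from the rollout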
clip_preds, clip_preds_pix, clip_preds_addl_losses = self._forward_dyn(
clip_hist, vid_hist, n_fwd_times, need_intermediate)
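# Decode predicted features into pixels only when requested; otherwise keep
# placeholders so downstream code can still index per-step outputs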
if run_decode:
clip_preds_pix = self._forward_dec(clip_preds, clip_preds_pix)
else:
clip_preds_pix = [None] * len(clip_preds)
# Compute the solved-or-not prediction; only done for the clips asked for
clip_preds_solved = self._cls(
combine_obj_pixels(clip_hist, 2), combine_obj_pixels(vid_hist, 2),
[combine_obj_pixels(el, 1) for el in clip_preds],
[combine_obj_pixels(el, 1) for el in clip_preds_pix])
all_losses = []
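# Undo the preprocessing on the predicted pixels so losses are computed in
# the original frame space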
clip_preds_pix_unpreproc_for_loss = [
self.preproc.unpreprocess_frame_for_loss(el)
for el in clip_preds_pix
]
if compute_losses:
for i in range(min(len(clip_preds), n_fwd_times_incur_loss)):
# Compute losses at each prediction step if need_intermediate is
# set. Otherwise only a single output (the last prediction) is
# returned, so the loss can only be incurred at that point.
if not need_intermediate:
assert len(clip_preds) == 1
pred_id = -1
# Only incur loss on the final rolled-out observation
this_fwd_times = n_fwd_times
else:
assert len(clip_preds) == n_fwd_times
pred_id = i
this_fwd_times = i + 1
all_losses.append(
self._compute_losses(
# For the loss, using only the last prediction (for now)
clip_preds[pred_id],
combine_obj_pixels(
clip_preds_pix_unpreproc_for_loss[pred_id], 1),
obj_feat,
combine_obj_pixels(vid, 2),
n_hist_frames,
this_fwd_times))
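# Average the per-step losses, then fold in the rollout's additional
# losses and the solved-or-not classification loss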
all_losses = average_losses(all_losses)
all_losses.update(average_losses(clip_preds_addl_losses))
all_losses.update(
self.solved_or_not_loss(clip_preds_solved, vid_is_solved))
# Add losses on the provided frames if requested
if run_decode and autoenc_loss_ratio > 0:
all_losses.update(
self.autoencoder_loss(combine_obj_pixels(vid, 2), obj_feat,
autoenc_loss_ratio))
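# Finish un-preprocessing the predicted pixels and merge the per-object
# channels for the returned predictions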
clip_preds_pix_unpreproc = [
combine_obj_pixels(self.preproc.unpreprocess_frame_after_loss(el),
1) for el in clip_preds_pix_unpreproc_for_loss
]
all_preds = {
'feats': clip_preds,
'is_solved': clip_preds_solved,
'pixels': clip_preds_pix_unpreproc,
}
return all_preds, all_losses
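
# A minimal usage sketch (assumption: `net` is an instance of this module and
# `vid` / `vid_is_solved` are tensors with the shapes documented above; the
# variable names and argument values here are illustrative, not from this file):
#
#   all_preds, all_losses = net(
#       vid,                     # (B, T, Nobj, H, W) input video
#       vid_is_solved,           # (B, 1) labels, or None at test time
#       n_hist_frames=3,
#       n_fwd_times=5,
#       run_decode=True,         # also decode features into pixel frames
#       compute_losses=True,     # train-time only
#       need_intermediate=True)  # keep every rollout step, not just the last
#   rollout_pixels = all_preds['pixels']    # list of per-step decoded frames
#   solved_scores = all_preds['is_solved']  # output of the solved-or-not classifier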