def forward()

in agents/nets.py


    def forward(self,
                vid,
                vid_is_solved,
                n_hist_frames=3,
                n_fwd_times=1,
                n_fwd_times_incur_loss=999999,
                run_decode=False,
                compute_losses=False,
                need_intermediate=False,
                autoenc_loss_ratio=0.0,
                nslices=-1):
        """
        Args:
            vid: (BxTxNobjxHxW) The input video
            vid_is_solved: (Bx1) Whether the video is solved in the end of not.
                Could be None at test time.
            n_hist_frames: (int) Number of frames to use as history for
                prediction
            n_fwd_times: (int) How many times to run the forward dynamics model
            n_fwd_times_incur_loss (int): Upto how many of these forwards to
                incur loss on.
            run_decode: (bool) Decode the features into pixel output
            compute_losses: Should be set at train time. Will compute losses,
                whatever it can given the data (eg, if vid_is_solved is not
                passed to the function, it will not compute the CE loss).
            need_intermediate (bool): Set true if you want to run the dynamics
                model and need all the intermediate results. Else, will return
                a list with only 1 element, the final output.
            autoenc_loss_ratio (float btw 0-1): Set to 1 to run auto-encoder
                style loss on all frames when run_decode is set.
            num_slices (int): See in the _slice_for_dyn fn
        Returns:
            clip_feat: BxTxD
        """
        vid_preproc = self.preproc.preprocess_vid(vid)
        obj_feat = self.enc(vid_preproc)
        clip_hist = self._slice_for_dyn(obj_feat,
                                        n_hist_frames,
                                        nslices=nslices)
        vid_hist = self._slice_for_dyn(vid_preproc,
                                       n_hist_frames,
                                       nslices=nslices)
        assert clip_hist.shape[1] == n_hist_frames
        clip_hist = self.interactor(clip_hist)
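        # Roll out the dynamics model n_fwd_times steps from the history
        # features (keeping intermediates only if need_intermediate is set)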
        clip_preds, clip_preds_pix, clip_preds_addl_losses = self._forward_dyn(
            clip_hist, vid_hist, n_fwd_times, need_intermediate)
        if run_decode:
            clip_preds_pix = self._forward_dec(clip_preds, clip_preds_pix)
        else:
            clip_preds_pix = [None] * len(clip_preds)
        # Compute the solved/not-solved prediction; only done for the clips
        # that were asked for
        clip_preds_solved = self._cls(
            combine_obj_pixels(clip_hist, 2), combine_obj_pixels(vid_hist, 2),
            [combine_obj_pixels(el, 1) for el in clip_preds],
            [combine_obj_pixels(el, 1) for el in clip_preds_pix])
        all_losses = []
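        # Undo the preprocessing on the decoded frames so pixel losses are
        # computed in the original frame space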
        clip_preds_pix_unpreproc_for_loss = [
            self.preproc.unpreprocess_frame_for_loss(el)
            for el in clip_preds_pix
        ]
        if compute_losses:
            for i in range(min(len(clip_preds), n_fwd_times_incur_loss)):
                # If need_intermediate is set, incur losses at each
                # prediction step. Otherwise only a single output (the last
                # prediction) is available, so the loss can only be incurred
                # at that point.
                if not need_intermediate:
                    assert len(clip_preds) == 1
                    pred_id = -1
                    # Only loss on predicting the final rolled out obs
                    this_fwd_times = n_fwd_times
                else:
                    assert len(clip_preds) == n_fwd_times
                    pred_id = i
                    this_fwd_times = i + 1
                all_losses.append(
                    self._compute_losses(
                        # Incur loss on the prediction selected above (the
                        # final rollout if intermediates were not kept)
                        clip_preds[pred_id],
                        combine_obj_pixels(
                            clip_preds_pix_unpreproc_for_loss[pred_id], 1),
                        obj_feat,
                        combine_obj_pixels(vid, 2),
                        n_hist_frames,
                        this_fwd_times))
            all_losses = average_losses(all_losses)
            all_losses.update(average_losses(clip_preds_addl_losses))
            all_losses.update(
                self.solved_or_not_loss(clip_preds_solved, vid_is_solved))
            # Add losses on the provided frames if requested
            if run_decode and autoenc_loss_ratio > 0:
                all_losses.update(
                    self.autoencoder_loss(combine_obj_pixels(vid, 2), obj_feat,
                                          autoenc_loss_ratio))
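        # Convert the loss-space frames back to the final output pixel space
        # and merge the per-object channels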
        clip_preds_pix_unpreproc = [
            combine_obj_pixels(self.preproc.unpreprocess_frame_after_loss(el),
                               1) for el in clip_preds_pix_unpreproc_for_loss
        ]
        all_preds = {
            'feats': clip_preds,
            'is_solved': clip_preds_solved,
            'pixels': clip_preds_pix_unpreproc,
        }
        return all_preds, all_losses
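
A minimal usage sketch (not from the repository): it assumes `net` is an instance of the module defining this forward(), and the tensor shapes and dtypes below are illustrative, since the exact input format is dictated by self.preproc.

    import torch

    # Hypothetical setup: B=4 videos, T=8 frames, Nobj=2 object channels, 64x64.
    vid = torch.rand(4, 8, 2, 64, 64)            # BxTxNobjxHxW (format assumed)
    vid_is_solved = torch.randint(0, 2, (4, 1))  # Bx1 labels; may be None at test time

    # Train-time call: decode pixels and incur losses on every rollout step.
    all_preds, all_losses = net(
        vid,
        vid_is_solved,
        n_hist_frames=3,
        n_fwd_times=5,
        n_fwd_times_incur_loss=5,
        run_decode=True,
        compute_losses=True,
        need_intermediate=True,
        autoenc_loss_ratio=0.1)
    # all_preds['feats'], all_preds['is_solved'], all_preds['pixels']

    # Test-time call: no labels, no losses; only the final rollout prediction.
    all_preds, _ = net(vid, None, n_fwd_times=10)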