def predict_deltas()

in occant_baselines/rl/policy.py [0:0]


    def predict_deltas(self, x, masks=None):
        # Normalize RGB and transpose all multichannel inputs to channels-first
        st_1 = process_image(x["rgb_at_t_1"], self.img_mean_t, self.img_std_t)
        dt_1 = transpose_image(x["depth_at_t_1"])
        ego_map_gt_at_t_1 = transpose_image(x["ego_map_gt_at_t_1"])
        st = process_image(x["rgb_at_t"], self.img_mean_t, self.img_std_t)
        dt = transpose_image(x["depth_at_t"])
        ego_map_gt_at_t = transpose_image(x["ego_map_gt_at_t"])
        # These inputs are present only for a baseline that uses
        # ground-truth anticipated maps
        if (
            "ego_map_gt_anticipated_at_t_1" in x
            and x["ego_map_gt_anticipated_at_t_1"] is not None
        ):
            ego_map_gt_anticipated_at_t_1 = transpose_image(
                x["ego_map_gt_anticipated_at_t_1"]
            )
            ego_map_gt_anticipated_at_t = transpose_image(
                x["ego_map_gt_anticipated_at_t"]
            )
        else:
            ego_map_gt_anticipated_at_t_1 = None
            ego_map_gt_anticipated_at_t = None
        # Compute past and current egocentric maps
        bs = st_1.size(0)  # batch size; used to split the stacked outputs below
        pu_inputs_t_1 = {
            "rgb": st_1,
            "depth": dt_1,
            "ego_map_gt": ego_map_gt_at_t_1,
            "ego_map_gt_anticipated": ego_map_gt_anticipated_at_t_1,
        }
        pu_inputs_t = {
            "rgb": st,
            "depth": dt,
            "ego_map_gt": ego_map_gt_at_t,
            "ego_map_gt_anticipated": ego_map_gt_anticipated_at_t,
        }
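        # Stack the t-1 and t inputs along the batch dimension so the
        # projection unit processes both time steps in a single forward pass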
        pu_inputs = self._safe_cat(pu_inputs_t_1, pu_inputs_t)
        pu_outputs = self.projection_unit(pu_inputs)
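        # Split the stacked outputs back into their t-1 and t halves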
        pu_outputs_t = {k: v[bs:] for k, v in pu_outputs.items()}
        pt_1, pt = pu_outputs["occ_estimate"][:bs], pu_outputs["occ_estimate"][bs:]
        # Compute the relative pose change from t-1 to t
        dx = subtract_pose(x["pose_at_t_1"], x["pose_at_t"])
        # Estimate pose: default to the input poses; these are overwritten
        # below when the pose estimator is enabled
        dx_hat = dx
        xt_hat = x["pose_at_t"]
        all_pose_outputs = None
        if not self.config.ignore_pose_estimator:
            all_pose_outputs = {}
            pose_inputs = {}
            if "rgb" in self.config.pose_predictor_inputs:
                pose_inputs["rgb_t_1"] = st_1
                pose_inputs["rgb_t"] = st
            if "depth" in self.config.pose_predictor_inputs:
                pose_inputs["depth_t_1"] = dt_1
                pose_inputs["depth_t"] = dt
            if "ego_map" in self.config.pose_predictor_inputs:
                pose_inputs["ego_map_t_1"] = pt_1
                pose_inputs["ego_map_t"] = pt
            if self.config.detach_map:
                for k in pose_inputs.keys():
                    pose_inputs[k] = pose_inputs[k].detach()
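            # _transform_observations presumably registers the t-1 inputs to
            # the frame at t using the relative pose dx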
            n_pose_inputs = self._transform_observations(pose_inputs, dx)
            pose_outputs = self.pose_estimator(n_pose_inputs)
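            # The estimator predicts a residual correction that is composed
            # with dx rather than predicting the full delta from scratch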
            dx_hat = add_pose(dx, pose_outputs["pose"])
            all_pose_outputs["pose_outputs"] = pose_outputs
            # Estimate global pose
            xt_hat = add_pose(x["pose_hat_at_t_1"], dx_hat)
        # Zero out pose predictions when the episode resets
        # (masks are 0 at episode boundaries)
        if masks is not None:
            xt_hat = xt_hat * masks
            dx_hat = dx_hat * masks
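        # Assemble outputs: pt is the current-step map estimate; dx_hat and
        # xt_hat are the relative and global pose estimates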
        outputs = {
            "pt": pt,
            "dx_hat": dx_hat,
            "xt_hat": xt_hat,
            "all_pu_outputs": pu_outputs_t,
            "all_pose_outputs": all_pose_outputs,
        }
        if "ego_map_hat" in pu_outputs_t:
            outputs["ego_map_hat_at_t"] = pu_outputs_t["ego_map_hat"]
        return outputs
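
Note: the helpers used above (process_image, transpose_image, subtract_pose, add_pose, _safe_cat) are not shown in this section. Below is a minimal sketch of the pose arithmetic and the dict concatenation as inferred from their call sites, assuming poses are (bs, 3) tensors of (x, y, heading) composed in SE(2); the repository's actual conventions may differ.

    import torch

    def subtract_pose(pose_a, pose_b):
        # Relative pose of pose_b expressed in pose_a's reference frame
        dx = pose_b[:, 0] - pose_a[:, 0]
        dy = pose_b[:, 1] - pose_a[:, 1]
        ta = pose_a[:, 2]
        # Rotate the world-frame displacement into pose_a's frame
        x_rel = dx * torch.cos(ta) + dy * torch.sin(ta)
        y_rel = -dx * torch.sin(ta) + dy * torch.cos(ta)
        t_rel = pose_b[:, 2] - pose_a[:, 2]
        return torch.stack([x_rel, y_rel, t_rel], dim=1)

    def add_pose(pose_a, delta):
        # Compose pose_a with a relative pose given in pose_a's frame,
        # the inverse of subtract_pose above
        ta = pose_a[:, 2]
        x = pose_a[:, 0] + delta[:, 0] * torch.cos(ta) - delta[:, 1] * torch.sin(ta)
        y = pose_a[:, 1] + delta[:, 0] * torch.sin(ta) + delta[:, 1] * torch.cos(ta)
        t = pose_a[:, 2] + delta[:, 2]
        return torch.stack([x, y, t], dim=1)

    def _safe_cat(d1, d2):
        # Concatenate two dicts of tensors along the batch dimension,
        # passing None entries through unchanged (inferred from the
        # bs-based slicing of pu_outputs in predict_deltas)
        return {
            k: (torch.cat([v, d2[k]], dim=0) if v is not None else None)
            for k, v in d1.items()
        }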