def _transform_observations()

in occant_baselines/rl/policy.py [0:0]


    def _transform_observations(self, inputs, dx, *, align_corners=True):
        """Converts observations from t-1 to coordinate frame for t.

        ``inputs`` is modified in place and also returned:

        * ``ego_map_t_1`` (if present) is replaced by
          ``self._bottom_row_spatial_transform(ego_map_t_1, dx, invert=True)``.
        * ``depth_t_1`` (if present) is un-normalized with the EGO_PROJECTION
          min/max depth and unprojected to 3D with a pinhole model built from
          the configured hfov/vfov. The points are then moved by the relative
          pose ``dx`` and reprojected, and the result is used as the sampling
          grid for ``F.grid_sample``. The warped depth replaces ``depth_t_1``,
          and ``rgb_t_1`` (if present) is warped with the same grid. Without
          ``depth_t_1``, ``rgb_t_1`` is left untouched.

        Args:
            inputs: dict of observation tensors. ``depth_t_1`` must have a
                singleton channel dimension, i.e. (bs, 1, H, W); the einops
                pattern below enforces this. It is presumably normalized so
                that 0 -> min_depth and 1 -> max_depth; confirm with caller.
            dx: relative pose per batch element. dx[:, 0] is the motion along
                -Z, dx[:, 1] the motion along X, and dx[:, 2] the rotation
                from -Z to X.
            align_corners: forwarded to ``F.grid_sample``. The pixel grid is
                ``linspace(-1, 1, ...)``, which puts +/-1 at the centres of
                the corner pixels. That is the ``align_corners=True``
                convention. With True, a zero ``dx`` returns the images
                unchanged wherever depth is positive. With PyTorch's default
                (False), every pixel shifts by up to half a pixel and the
                border is blended with the zero padding. Pass False only to
                reproduce results obtained under that default.

        Returns:
            The same ``inputs`` dict.
        """
        # ====================== Transform past egocentric map ========================
        if "ego_map_t_1" in inputs:
            inputs["ego_map_t_1"] = self._bottom_row_spatial_transform(
                inputs["ego_map_t_1"], dx, invert=True
            )
        occ_cfg = self.projection_unit.main.config
        # ========================= Transform rgb and depth ===========================
        if "depth_t_1" not in inputs:
            return inputs

        depth_t_1 = inputs["depth_t_1"]
        device = depth_t_1.device
        H, W = depth_t_1.shape[2:]

        # Intrinsic camera matrix (4, 4). It depends only on the config, so it
        # is built once and cached on the CPU; the inverse is cached likewise.
        if "K" not in self._cache:
            hfov = math.radians(occ_cfg.EGO_PROJECTION.hfov)
            vfov = math.radians(occ_cfg.EGO_PROJECTION.vfov)
            self._cache["K"] = torch.Tensor(
                [
                    [1 / math.tan(hfov / 2.0), 0.0, 0.0, 0.0],
                    [0.0, 1 / math.tan(vfov / 2.0), 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                ]
            )
        K = self._cache["K"].to(device)
        if "invK" not in self._cache:
            self._cache["invK"] = torch.inverse(K).cpu()
        invK = self._cache["invK"].to(device)

        # Normalized pixel grid, shape (1, 1, H, W): x grows left->right and
        # y grows bottom->top (y-up). Rebuild it if the image resolution
        # differs from the cached grid; otherwise broadcasting below would fail.
        if "xs" not in self._cache or tuple(self._cache["xs"].shape[-2:]) != (H, W):
            grid_x, grid_y = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
            self._cache["xs"] = torch.Tensor(grid_x.reshape(1, 1, H, W))
            self._cache["ys"] = torch.Tensor(grid_y.reshape(1, 1, H, W))
        xs = self._cache["xs"].to(device)
        ys = self._cache["ys"].to(device)

        # Unproject. Depth is negated because the camera looks along -Z.
        # NOTE: the t-1 depth is used as a proxy for the depth at the target
        # (camera 2, time t) pixels.
        min_depth = occ_cfg.EGO_PROJECTION.min_depth
        max_depth = occ_cfg.EGO_PROJECTION.max_depth
        depth_t_1_unnorm = depth_t_1 * (max_depth - min_depth) + min_depth
        xys = torch.stack(
            [
                xs * depth_t_1_unnorm,
                ys * depth_t_1_unnorm,
                -depth_t_1_unnorm,
                torch.ones_like(depth_t_1_unnorm),
            ],
            dim=4,
        )  # (bs, 1, H, W, 4)
        xys = rearrange(xys, "b () h w f -> b (h w) f")  # (bs, HW, 4)
        # Row vectors times invK equals (invK @ column vectors)^T only because
        # K (and hence invK) is diagonal, i.e. symmetric.
        xy_c2 = torch.matmul(xys, invK.unsqueeze(0))  # (bs, HW, 4)

        # ================ Camera 2 --> Camera 1 transformation ===============
        # To backward-warp the t-1 (camera 1) images into the t (camera 2)
        # frame, map each target point from camera 2 into camera 1
        # coordinates.
        translation = torch.stack(
            [dx[:, 1], torch.zeros_like(dx[:, 1]), -dx[:, 0]], dim=1
        )  # (bs, 3)
        # Right-hand-rule rotation about the Y axis by -dx[:, 2].
        cos_theta = torch.cos(-dx[:, 2])
        sin_theta = torch.sin(-dx[:, 2])
        T_camera1_camera2 = torch.zeros(xy_c2.shape[0], 4, 4, device=device)
        T_camera1_camera2[:, 0, 0].copy_(cos_theta)
        T_camera1_camera2[:, 0, 2].copy_(sin_theta)
        T_camera1_camera2[:, 1, 1].fill_(1.0)
        T_camera1_camera2[:, 2, 0].copy_(-sin_theta)
        T_camera1_camera2[:, 2, 2].copy_(cos_theta)
        T_camera1_camera2[:, :3, 3].copy_(translation)
        T_camera1_camera2[:, 3, 3].fill_(1.0)
        xy_c1 = torch.matmul(T_camera1_camera2, xy_c2.transpose(1, 2))  # (bs, 4, HW)

        # Project camera-1 points to normalized image coordinates. The small
        # epsilon guards the division against zero depth.
        xy_newimg = torch.matmul(K, xy_c1).transpose(1, 2)  # (bs, HW, 4)
        xys_newimg = xy_newimg[:, :, :2] / (
            -xy_newimg[:, :, 2:3] + 1e-8
        )  # (bs, HW, 2)
        # Flip back to y-down, which is what grid_sample expects (-1 = top row).
        xys_newimg[:, :, 1] *= -1

        # ================== Apply warp to RGB, Depth images ==================
        sampler = rearrange(xys_newimg, "b (h w) f -> b h w f", h=H, w=W)
        inputs["depth_t_1"] = F.grid_sample(
            depth_t_1, sampler, padding_mode="zeros", align_corners=align_corners
        )
        if "rgb_t_1" in inputs:
            inputs["rgb_t_1"] = F.grid_sample(
                inputs["rgb_t_1"],
                sampler,
                padding_mode="zeros",
                align_corners=align_corners,
            )

        return inputs