in occant_baselines/rl/policy.py [0:0]
def _transform_observations(self, inputs, dx):
"""Converts observations from t-1 to coordinate frame for t.
"""
# ====================== Transform past egocentric map ========================
if "ego_map_t_1" in inputs:
ego_map_t_1 = inputs["ego_map_t_1"]
ego_map_t_1_trans = self._bottom_row_spatial_transform(
ego_map_t_1, dx, invert=True
)
inputs["ego_map_t_1"] = ego_map_t_1_trans
occ_cfg = self.projection_unit.main.config
# ========================= Transform rgb and depth ===========================
if "depth_t_1" in inputs:
device = inputs["depth_t_1"].device
depth_t_1 = inputs["depth_t_1"]
if "K" not in self._cache.keys():
# Project images from previous camera pose to current camera pose
# Compute intrinsic camera matrix
hfov = math.radians(occ_cfg.EGO_PROJECTION.hfov)
vfov = math.radians(occ_cfg.EGO_PROJECTION.vfov)
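            # Pinhole intrinsics for image coordinates normalized to [-1, 1]:
            # focal lengths set by the FOVs, principal point at the center (0, 0)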
K = torch.Tensor(
[
[1 / math.tan(hfov / 2.0), 0.0, 0.0, 0.0],
[0.0, 1 / math.tan(vfov / 2.0), 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
).to(
device
) # (4, 4)
self._cache["K"] = K.cpu()
else:
K = self._cache["K"].to(device)
H, W = depth_t_1.shape[2:]
min_depth = occ_cfg.EGO_PROJECTION.min_depth
max_depth = occ_cfg.EGO_PROJECTION.max_depth
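        # Undo the [0, 1] depth normalization to recover metric depth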
depth_t_1_unnorm = depth_t_1 * (max_depth - min_depth) + min_depth
if "xs" not in self._cache.keys():
xs, ys = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, H))
xs = torch.Tensor(xs.reshape(1, H, W)).to(device).unsqueeze(0)
ys = torch.Tensor(ys.reshape(1, H, W)).to(device).unsqueeze(0)
self._cache["xs"] = xs.cpu()
self._cache["ys"] = ys.cpu()
else:
xs = self._cache["xs"].to(device)
ys = self._cache["ys"].to(device)
# Unproject
# negate depth as the camera looks along -Z
xys = torch.stack(
[
xs * depth_t_1_unnorm,
ys * depth_t_1_unnorm,
-depth_t_1_unnorm,
torch.ones_like(depth_t_1_unnorm),
],
dim=4,
) # (bs, 1, H, W, 4)
        xys = rearrange(xys, "b () h w f -> b (h w) f")  # (bs, HW, 4)
if "invK" not in self._cache.keys():
invK = torch.inverse(K)
self._cache["invK"] = invK.cpu()
else:
invK = self._cache["invK"].to(device)
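        # Apply the inverse intrinsics to obtain points in the frame of the
        # target camera (camera 2)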
xy_c2 = torch.matmul(xys, invK.unsqueeze(0))
        # ================ Camera 2 --> Camera 1 transformation ===============
        # Camera 2 is the target view (time t) and camera 1 is the source view
        # (time t-1); inverse warping from camera 1 to camera 2 requires the
        # target-to-source transformation. In dx, dx[:, 0] is translation along
        # -Z, dx[:, 1] is translation along X, and dx[:, 2] is the rotation
        # from -Z towards X.
translation = torch.stack(
[dx[:, 1], torch.zeros_like(dx[:, 1]), -dx[:, 0]], dim=1
) # (bs, 3)
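        # Camera 2's position expressed in the world (camera 1) frame: lateral
        # motion along X, forward motion along -Z, no vertical motion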
T_world_camera2 = torch.zeros(xy_c2.shape[0], 4, 4).to(device)
        # Right-hand-rule rotation about the Y axis by angle -dx[:, 2]
cos_theta = torch.cos(-dx[:, 2])
sin_theta = torch.sin(-dx[:, 2])
T_world_camera2[:, 0, 0].copy_(cos_theta)
T_world_camera2[:, 0, 2].copy_(sin_theta)
T_world_camera2[:, 1, 1].fill_(1.0)
T_world_camera2[:, 2, 0].copy_(-sin_theta)
T_world_camera2[:, 2, 2].copy_(cos_theta)
T_world_camera2[:, :3, 3].copy_(translation)
T_world_camera2[:, 3, 3].fill_(1.0)
        # T_world_camera2 maps camera 2 coordinates to the world frame; the
        # world frame coincides with camera 1 (time t-1), so this is also the
        # camera 2 --> camera 1 transformation.
        T_camera1_camera2 = T_world_camera2  # (bs, 4, 4)
xy_c1 = torch.matmul(
T_camera1_camera2, xy_c2.transpose(1, 2)
) # (bs, 4, HW)
# Convert camera coordinates to image coordinates
xy_newimg = torch.matmul(K, xy_c1) # (bs, 4, HW)
xy_newimg = xy_newimg.transpose(1, 2) # (bs, HW, 4)
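        # Perspective divide: Z is negated since the camera looks along -Z, and
        # the epsilon guards against division by zero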
xys_newimg = xy_newimg[:, :, :2] / (
-xy_newimg[:, :, 2:3] + 1e-8
) # (bs, HW, 2)
# Flip back to y-down to match array indexing
xys_newimg[:, :, 1] *= -1 # (bs, HW, 2)
# ================== Apply warp to RGB, Depth images ==================
sampler = rearrange(xys_newimg, "b (h w) f -> b h w f", h=H, w=W)
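        # grid_sample expects a (bs, H, W, 2) grid of (x, y) sampling locations
        # in [-1, 1]; pixels that fall outside the previous view are zero-filled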
depth_t_1_trans = F.grid_sample(depth_t_1, sampler, padding_mode="zeros")
inputs["depth_t_1"] = depth_t_1_trans
if "rgb_t_1" in inputs:
rgb_t_1 = inputs["rgb_t_1"]
rgb_t_1_trans = F.grid_sample(rgb_t_1, sampler, padding_mode="zeros")
inputs["rgb_t_1"] = rgb_t_1_trans
return inputs
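
A minimal, hypothetical call-site sketch (not from the repository) showing the inputs this method expects, assuming depth and RGB tensors of shape (bs, C, H, W) with depth normalized to [0, 1], a (bs, 3) pose change dx, and a hypothetical `mapper` object exposing the method:

# Hypothetical usage; `mapper`, the variable names, and the shapes are assumptions.
inputs = {
    "rgb_t_1": rgb_t_1,      # (bs, 3, H, W) RGB frame from t-1
    "depth_t_1": depth_t_1,  # (bs, 1, H, W) depth from t-1, normalized to [0, 1]
    "ego_map_t_1": ego_map,  # egocentric map from t-1
}
inputs = mapper._transform_observations(inputs, dx)  # dx: (bs, 3) pose change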