occant_baselines/rl/policy.py
def predict_deltas(self, x, masks=None):
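    """Predict the egocentric map and pose change between timesteps t-1 and t.

    Args:
        x: dict of observations with RGB, depth, ground-truth ego maps and poses
           at t-1 and t, plus the running pose estimate at t-1 ("pose_hat_at_t_1").
        masks: optional mask that zeroes out the pose estimates.

    Returns:
        dict with the current map estimate ("pt"), relative and global pose
        estimates ("dx_hat", "xt_hat"), and auxiliary module outputs.
    """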
# Transpose multichannel inputs
st_1 = process_image(x["rgb_at_t_1"], self.img_mean_t, self.img_std_t)
dt_1 = transpose_image(x["depth_at_t_1"])
ego_map_gt_at_t_1 = transpose_image(x["ego_map_gt_at_t_1"])
st = process_image(x["rgb_at_t"], self.img_mean_t, self.img_std_t)
dt = transpose_image(x["depth_at_t"])
ego_map_gt_at_t = transpose_image(x["ego_map_gt_at_t"])
    # Ground-truth anticipated maps are present only for the baseline that uses them
    if (
        "ego_map_gt_anticipated_at_t_1" in x
        and x["ego_map_gt_anticipated_at_t_1"] is not None
    ):
        ego_map_gt_anticipated_at_t_1 = transpose_image(
            x["ego_map_gt_anticipated_at_t_1"]
        )
        ego_map_gt_anticipated_at_t = transpose_image(
            x["ego_map_gt_anticipated_at_t"]
        )
    else:
        ego_map_gt_anticipated_at_t_1 = None
        ego_map_gt_anticipated_at_t = None
    # Compute past and current egocentric maps
    bs = st_1.size(0)
    pu_inputs_t_1 = {
        "rgb": st_1,
        "depth": dt_1,
        "ego_map_gt": ego_map_gt_at_t_1,
        "ego_map_gt_anticipated": ego_map_gt_anticipated_at_t_1,
    }
    pu_inputs_t = {
        "rgb": st,
        "depth": dt,
        "ego_map_gt": ego_map_gt_at_t,
        "ego_map_gt_anticipated": ego_map_gt_anticipated_at_t,
    }
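    # Batch the t-1 and t inputs together so the projection unit runs a single
    # forward pass, then split its occupancy estimates back by batch index.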
    pu_inputs = self._safe_cat(pu_inputs_t_1, pu_inputs_t)
    pu_outputs = self.projection_unit(pu_inputs)
    pu_outputs_t = {k: v[bs:] for k, v in pu_outputs.items()}
    pt_1, pt = pu_outputs["occ_estimate"][:bs], pu_outputs["occ_estimate"][bs:]
    # Compute the relative pose between the two timesteps
    dx = subtract_pose(x["pose_at_t_1"], x["pose_at_t"])
    # Estimate pose (falls back to the raw relative pose if the estimator is disabled)
    dx_hat = dx
    xt_hat = x["pose_at_t"]
    all_pose_outputs = None
    if not self.config.ignore_pose_estimator:
        all_pose_outputs = {}
        pose_inputs = {}
        if "rgb" in self.config.pose_predictor_inputs:
            pose_inputs["rgb_t_1"] = st_1
            pose_inputs["rgb_t"] = st
        if "depth" in self.config.pose_predictor_inputs:
            pose_inputs["depth_t_1"] = dt_1
            pose_inputs["depth_t"] = dt
        if "ego_map" in self.config.pose_predictor_inputs:
            pose_inputs["ego_map_t_1"] = pt_1
            pose_inputs["ego_map_t"] = pt
        if self.config.detach_map:
            for k in pose_inputs.keys():
                pose_inputs[k] = pose_inputs[k].detach()
        n_pose_inputs = self._transform_observations(pose_inputs, dx)
        pose_outputs = self.pose_estimator(n_pose_inputs)
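        # Compose the estimator's output with dx, treating it as a correction
        # to the relative pose computed above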
        dx_hat = add_pose(dx, pose_outputs["pose"])
        all_pose_outputs["pose_outputs"] = pose_outputs
    # Estimate global pose
    xt_hat = add_pose(x["pose_hat_at_t_1"], dx_hat)
    # Zero out pose prediction based on the mask
    if masks is not None:
        xt_hat = xt_hat * masks
        dx_hat = dx_hat * masks
    outputs = {
        "pt": pt,
        "dx_hat": dx_hat,
        "xt_hat": xt_hat,
        "all_pu_outputs": pu_outputs_t,
        "all_pose_outputs": all_pose_outputs,
    }
    if "ego_map_hat" in pu_outputs_t:
        outputs["ego_map_hat_at_t"] = pu_outputs_t["ego_map_hat"]
    return outputs