in occant_baselines/rl/occant_nav_trainer.py [0:0]
def _prepare_batch(self, observations, device=None, actions=None):
    """Batch raw observations and attach the derived entries built here.

    Steps performed, in order:
      1. Batch ``observations`` with ``batch_obs`` onto the target device.
      2. Resize ``batch["rgb"]`` (bilinear) and ``batch["depth"]`` (nearest)
         to ``config.RL.ANS.image_scale_hw`` when their dim-1/dim-2 sizes
         differ from it. Both are treated as ``b h w c`` tensors.
      3. Run ``self.depth_projection_net`` on the channel-first depth and
         store the result channel-last under ``"ego_map_gt"``.
      4. Store ``self.prev_actions`` under ``"prev_actions"``.
      5. If the batch has no ``"pose"`` entry, fill one in: zeros of shape
         ``(self.envs.num_envs, 3)`` when ``self.prev_batch`` is None,
         otherwise the previous pose advanced by the delta derived from
         ``self.prev_actions``.

    Args:
        observations: Raw observations, passed straight to ``batch_obs``.
        device: Device for the batched tensors; defaults to ``self.device``.
        actions: Accepted for interface compatibility; not read in this
            method (the previous actions come from ``self.prev_actions``).

    Returns:
        The batch mapping produced by ``batch_obs`` with ``"ego_map_gt"``,
        ``"prev_actions"`` and (if it was missing) ``"pose"`` added.
    """
    imH, imW = self.config.RL.ANS.image_scale_hw
    device = self.device if device is None else device
    batch = batch_obs(observations, device=device)

    # Resize the image-like inputs to the configured (H, W). Depth uses
    # nearest-neighbour so no new depth values are blended in.
    for key, mode in (("rgb", "bilinear"), ("depth", "nearest")):
        image = batch[key]
        if image.size(1) != imH or image.size(2) != imW:
            # F.interpolate expects channel-first input.
            image = rearrange(image, "b h w c -> b c h w")
            image = F.interpolate(image, (imH, imW), mode=mode)
            batch[key] = rearrange(image, "b c h w -> b h w c")

    # Compute ego_map_gt from depth
    ego_map_gt_b = self.depth_projection_net(
        rearrange(batch["depth"], "b h w c -> b c h w")
    )
    batch["ego_map_gt"] = rearrange(ego_map_gt_b, "b c h w -> b h w c")

    # Add previous action to batch as well
    batch["prev_actions"] = self.prev_actions

    # Add a rough pose estimate if GT pose is not available
    if "pose" not in batch:
        if self.prev_batch is None:
            # Set initial pose estimate to zero. Allocate directly on the
            # same device as the rest of the batch instead of creating it
            # on the default device and copying to self.device.
            batch["pose"] = torch.zeros(self.envs.num_envs, 3, device=device)
        else:
            actions_delta = self._convert_actions_to_delta(self.prev_actions)
            batch["pose"] = add_pose(self.prev_batch["pose"], actions_delta)
    return batch