in occant_baselines/rl/occant_exp_trainer.py [0:0]
def _prepare_batch(self, observations, prev_batch=None, device=None, actions=None):
    """Batch raw observations and attach the derived entries used downstream.

    Steps performed:
      1. Collate ``observations`` with ``batch_obs`` on the target device.
      2. Resize ``batch["rgb"]`` (bilinear) and ``batch["depth"]`` (nearest)
         to ``config.RL.ANS.image_scale_hw`` when their height/width differ.
      3. Compute ``batch["ego_map_gt"]`` by passing the depth image through
         ``self.depth_projection_net``.
      4. Fill in ``batch["pose"]`` (only if absent) and ``batch["prev_actions"]``.

    Args:
        observations: raw observations, forwarded unchanged to ``batch_obs``.
        prev_batch: batch from the previous step. Required only when
            ``actions`` is given and the new batch has no ``"pose"`` entry,
            since the pose is then integrated from ``prev_batch["pose"]``.
        device: device to place tensors on; defaults to ``self.device``.
        actions: actions taken at the previous step. ``None`` marks the
            initialization case (zero pose / zero previous actions).

    Returns:
        The batch dictionary with ``"ego_map_gt"``, ``"pose"`` and
        ``"prev_actions"`` populated.
    """
    imH, imW = self.config.RL.ANS.image_scale_hw
    device = self.device if device is None else device
    batch = batch_obs(observations, device=device)

    # Images are stored channel-last (b h w c) while F.interpolate expects
    # channel-first, hence the round trip through rearrange. Depth uses
    # nearest-neighbour so that no interpolated depth values are invented.
    for key, mode in (("rgb", "bilinear"), ("depth", "nearest")):
        image = batch[key]
        if image.size(1) != imH or image.size(2) != imW:
            image = rearrange(image, "b h w c -> b c h w")
            image = F.interpolate(image, (imH, imW), mode=mode)
            batch[key] = rearrange(image, "b c h w -> b h w c")

    # Compute ego_map_gt from depth
    ego_map_gt_b = self.depth_projection_net(
        rearrange(batch["depth"], "b h w c -> b c h w")
    )
    batch["ego_map_gt"] = rearrange(ego_map_gt_b, "b c h w -> b h w c")

    if actions is None:
        # Initialization condition
        # Allocate directly on the resolved `device` (not `self.device`) so the
        # new entries live alongside the tensors produced by batch_obs.
        num_envs = self.envs.num_envs
        if "pose" not in batch:
            # If pose estimates are not available, start from a zero estimate.
            batch["pose"] = torch.zeros(num_envs, 3, device=device)
        batch["prev_actions"] = torch.zeros(num_envs, 1, device=device)
    else:
        # Rollouts condition
        if "pose" not in batch:
            # If pose estimates are not available, compute them from the
            # action taken, starting from the previous pose.
            assert prev_batch is not None, (
                "prev_batch is required to derive the pose from actions"
            )
            actions_delta = self._convert_actions_to_delta(actions)
            batch["pose"] = add_pose(prev_batch["pose"], actions_delta)
        batch["prev_actions"] = actions
    return batch