def _prepare_batch()

in occant_baselines/rl/occant_exp_trainer.py [0:0]


    def _prepare_batch(self, observations, prev_batch=None, device=None, actions=None):
        imH, imW = self.config.RL.ANS.image_scale_hw
        device = self.device if device is None else device
        batch = batch_obs(observations, device=device)
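        # Resize RGB to the configured (imH, imW) with bilinear interpolation if needed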
        if batch["rgb"].size(1) != imH or batch["rgb"].size(2) != imW:
            rgb = rearrange(batch["rgb"], "b h w c -> b c h w")
            rgb = F.interpolate(rgb, (imH, imW), mode="bilinear")
            batch["rgb"] = rearrange(rgb, "b c h w -> b h w c")
        if batch["depth"].size(1) != imH or batch["depth"].size(2) != imW:
            depth = rearrange(batch["depth"], "b h w c -> b c h w")
            depth = F.interpolate(depth, (imH, imW), mode="nearest")
            batch["depth"] = rearrange(depth, "b c h w -> b h w c")
        # Compute ego_map_gt by projecting depth into a top-down egocentric map
        ego_map_gt_b = self.depth_projection_net(
            rearrange(batch["depth"], "b h w c -> b c h w")
        )
        batch["ego_map_gt"] = rearrange(ego_map_gt_b, "b c h w -> b h w c")
        if actions is None:
            # Initialization condition
            # If pose estimates are not available, set the initial estimate to zeros.
            if "pose" not in batch:
                # Set initial pose estimate to zero
                batch["pose"] = torch.zeros(self.envs.num_envs, 3).to(self.device)
            batch["prev_actions"] = torch.zeros(self.envs.num_envs, 1).to(self.device)
        else:
            # Rollouts condition
            # If pose estimates are not available, compute them from the action taken.
            if "pose" not in batch:
                assert prev_batch is not None
                actions_delta = self._convert_actions_to_delta(actions)
                batch["pose"] = add_pose(prev_batch["pose"], actions_delta)
            batch["prev_actions"] = actions

        return batch
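
The two branches correspond to the two ways the method is called: at episode start it is called without actions, so the pose estimate and prev_actions are zero-initialized; during rollouts the previous batch and the actions just executed are passed in, and if the observations carry no "pose" entry the new pose is dead-reckoned by composing the per-action delta (_convert_actions_to_delta) onto the previous estimate with add_pose. A minimal usage sketch under those assumptions follows; trainer, policy_act, and num_steps are illustrative placeholders, not names from this file:

    # Hypothetical usage sketch; only _prepare_batch itself comes from this file.
    observations = trainer.envs.reset()
    batch = trainer._prepare_batch(observations)  # episode start: zero pose, zero prev_actions

    for _ in range(num_steps):  # num_steps: illustrative rollout length
        actions = policy_act(batch)  # placeholder for the policy's action selection
        outputs = trainer.envs.step([a.item() for a in actions])
        observations = [o for o, _, _, _ in outputs]
        prev_batch = batch
        batch = trainer._prepare_batch(
            observations, prev_batch=prev_batch, actions=actions
        )  # rollout step: pose dead-reckoned from actions when no pose sensor is present

Note that prev_batch is only required on the rollout path, and only when the observations lack a "pose" entry; the assert on that branch makes the dependency explicit.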