def collect_affordance_episodes()

in affordance_seg/collect_dset.py


    def collect_affordance_episodes(self):

        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        # add test episode information to config
        test_episodes = json.load(open(self.config.EVAL.DATASET))
        self.config.defrost()
        self.config.ENV.TEST_EPISODES = test_episodes
        self.config.ENV.TEST_EPISODE_COUNT = len(test_episodes)
        self.config.freeze()

        # [!!] Load checkpoint, create dir to save rollouts to, and copy checkpoint for reference
        checkpoint_path = self.config.LOAD
        os.makedirs(f'{self.config.OUT_DIR}/episodes/', exist_ok=True)
        shutil.copy(checkpoint_path, f'{self.config.OUT_DIR}/{os.path.basename(checkpoint_path)}')

        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")
        ppo_cfg = self.config.RL.PPO

        logger.info(f"env config: {self.config}")
        self.envs = construct_envs(self.config, get_env_class(self.config.ENV.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        # [!!] Log extra stuff
        logger.info(checkpoint_path)
        logger.info(f"num_steps: {self.config.ENV.NUM_STEPS}")

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

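        # reset all environments and collate the initial observations into a single batch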
        observations = self.envs.reset()
        batch = self.batch_obs(observations, self.device)

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )

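        # per-environment rollout state: recurrent hidden states, previous actions, and done masks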
        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(
            self.config.NUM_PROCESSES, 1, device=self.device, dtype=torch.long
        )
        not_done_masks = torch.zeros(
            self.config.NUM_PROCESSES, 1, device=self.device
        )
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]

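        # progress bar over environment steps; switch the policy to evaluation mode before collecting rollouts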
        pbar = tqdm.tqdm()
        self.actor_critic.eval()

        iteration = 0
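        # roll out the policy until the requested number of test episodes has been collected (or no envs remain)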
        while (
            len(stats_episodes) < self.config.ENV.TEST_EPISODE_COUNT
            and self.envs.num_envs > 0
        ):

            # [!!] Show more fine-grained progress. THOR is slow!
            pbar.update()

            # [!!] Show episodes collected so far
            if iteration % self.config.ENV.NUM_STEPS == 0:
                print(f'Iter: {iteration}')
                self.print_stats()
            iteration += 1

            current_episodes = self.envs.current_episodes()

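            # query the policy for the next action in each environment; gradients are not needed during collection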
            with torch.no_grad():
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = self.batch_obs(observations, self.device)

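            # mask is 0.0 for environments whose episode just ended and 1.0 otherwise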
            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(
                rewards, dtype=torch.float, device=self.device
            ).unsqueeze(1)

            current_episode_reward += rewards

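            # pause environments whose upcoming episode has already been collected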
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):

                if (
                    next_episodes[i]['scene_id'],
                    next_episodes[i]['episode_id'],
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:

                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i])
                    )
                    current_episode_reward[i] = 0

                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[
                        (
                            current_episodes[i]['scene_id'],
                            current_episodes[i]['episode_id'],
                        )
                    ] = episode_stats

                    # [!!] save this episode's per-step trajectory masks
                    self.save_episode(infos[i]['traj_masks'])

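            # drop paused environments from all rollout tensors so per-env indices stay aligned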
            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )


        self.envs.close()
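
For reference, a minimal sketch of what the save_episode helper invoked above might do. The function name save_episode_sketch, the file layout under OUT_DIR/episodes/, and the torch.save format are assumptions for illustration; the actual helper in collect_dset.py may store different fields or use a different format.

    import os
    import torch

    def save_episode_sketch(out_dir, traj_masks):
        # Hypothetical sketch only; not the repo's save_episode implementation.
        episode_dir = os.path.join(out_dir, 'episodes')
        os.makedirs(episode_dir, exist_ok=True)
        # index saved episodes by how many files already exist in the folder
        episode_idx = len(os.listdir(episode_dir))
        # persist the per-step masks for later affordance-segmentation training
        torch.save(
            {'traj_masks': traj_masks},
            os.path.join(episode_dir, f'{episode_idx:06d}.pt'),
        )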