def train()

in occant_baselines/rl/occant_exp_trainer.py


    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """

        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            devices=self._assign_devices(),
        )

        ppo_cfg = self.config.RL.PPO
        ans_cfg = self.config.RL.ANS
        mapper_cfg = self.config.RL.ANS.MAPPER
        occ_cfg = self.config.RL.ANS.OCCUPANCY_ANTICIPATOR
        self.device = (
            torch.device("cuda", self.config.TORCH_GPU_ID)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self.mapper_rollouts = self._create_mapper_rollouts(ppo_cfg, ans_cfg)
        self._setup_actor_critic_agent(ppo_cfg, ans_cfg)
        logger.info(
            "mapper_agent number of parameters: {}".format(
                sum(param.numel() for param in self.mapper_agent.parameters())
            )
        )
        logger.info(
            "local_agent number of parameters: {}".format(
                sum(param.numel() for param in self.local_agent.parameters())
            )
        )
        logger.info(
            "global_agent number of parameters: {}".format(
                sum(param.numel() for param in self.global_agent.parameters())
            )
        )
        mapper_rollouts = self.mapper_rollouts
        global_rollouts = self._create_global_rollouts(ppo_cfg, ans_cfg)
        local_rollouts = self._create_local_rollouts(ppo_cfg, ans_cfg)
        global_rollouts.to(self.device)
        local_rollouts.to(self.device)
        # ===================== Create statistics buffers =====================
        statistics_dict = {}
        # Mapper statistics
        statistics_dict["mapper"] = defaultdict(
            lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
        )
        # Local policy statistics
        local_episode_rewards = torch.zeros(self.envs.num_envs, 1)
        statistics_dict["local_policy"] = defaultdict(
            lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
        )
        window_local_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_local_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
        # Global policy statistics
        global_episode_rewards = torch.zeros(self.envs.num_envs, 1)
        statistics_dict["global_policy"] = defaultdict(
            lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
        )
        window_global_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_global_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
        # Overall count statistics
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        # ==================== Measuring memory consumption ===================
        for name, rollouts in [
            ("Mapper rollouts", mapper_rollouts),
            ("Local policy rollouts", local_rollouts),
            ("Global policy rollouts", global_rollouts),
        ]:
            total_memory_size = 0
            print(f"=================== {name} ===================")
            for k, v in rollouts.observations.items():
                mem = v.element_size() * v.nelement() * 1e-9
                print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
                total_memory_size += mem
            print(f"Total memory: {total_memory_size:>10.4f} GB")
        # Resume checkpoint if available
        (
            num_updates_start,
            count_steps_start,
            count_checkpoints,
        ) = self.resume_checkpoint()
        count_steps = count_steps_start

        imH, imW = ans_cfg.image_scale_hw
        M = ans_cfg.overall_map_size  # the global map tensors below are M x M
        # ==================== Create state variables =================
        state_estimates = {
            # Agent's pose estimate
            "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
            # Agent's map
            "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
            "recurrent_hidden_states": torch.zeros(
                1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
            ).to(self.device),
            "visited_states": torch.zeros(self.envs.num_envs, 1, M, M).to(self.device),
        }
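        # Ground-truth quantities used for computing rewards and evaluation
        # metrics (area seen, pose-estimation error, etc.)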
        ground_truth_states = {
            # To measure area seen
            "visible_occupancy": torch.zeros(
                self.envs.num_envs, 2, M, M, device=self.device
            ),
            "pose": torch.zeros(self.envs.num_envs, 3, device=self.device),
            "prev_global_reward_metric": torch.zeros(
                self.envs.num_envs, 1, device=self.device
            ),
        }
        if (
            ans_cfg.reward_type == "map_accuracy"
            or ans_cfg.LOCAL_POLICY.learning_algorithm == "il"
        ):
            ground_truth_states["environment_layout"] = torch.zeros(
                self.envs.num_envs, 2, M, M
            ).to(self.device)
        masks = torch.zeros(self.envs.num_envs, 1)
        episode_step_count = torch.zeros(self.envs.num_envs, 1, device=self.device)

        # ==================== Reset the environments =================
        observations = self.envs.reset()
        batch = self._prepare_batch(observations)
        prev_batch = batch
        # Update visible occupancy
        ground_truth_states["visible_occupancy"] = self.mapper.ext_register_map(
            ground_truth_states["visible_occupancy"],
            rearrange(batch["ego_map_gt"], "b h w c -> b c h w"),
            batch["pose_gt"],
        )
        ground_truth_states["pose"].copy_(batch["pose_gt"])

        current_local_episode_reward = torch.zeros(self.envs.num_envs, 1)
        current_global_episode_reward = torch.zeros(self.envs.num_envs, 1)
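        # Running per-environment episode statistics and a sliding window of
        # them for tensorboard logging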
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            local_reward=torch.zeros(self.envs.num_envs, 1),
            global_reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        # Useful variables
        NUM_MAPPER_STEPS = ans_cfg.MAPPER.num_mapper_steps
        NUM_LOCAL_STEPS = ppo_cfg.num_local_steps
        NUM_GLOBAL_STEPS = ppo_cfg.num_global_steps
        GLOBAL_UPDATE_INTERVAL = NUM_GLOBAL_STEPS * ans_cfg.goal_interval
        NUM_GLOBAL_UPDATES_PER_EPISODE = self.config.T_EXP // GLOBAL_UPDATE_INTERVAL
        NUM_GLOBAL_UPDATES = (
            self.config.NUM_EPISODES
            * NUM_GLOBAL_UPDATES_PER_EPISODE
            // self.config.NUM_PROCESSES
        )
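        # Scheduling: one global update spans NUM_GLOBAL_STEPS subgoal samples,
        # i.e. NUM_GLOBAL_STEPS * goal_interval environment steps. For example
        # (illustrative values only), goal_interval = 25, NUM_GLOBAL_STEPS = 20
        # and T_EXP = 1000 give GLOBAL_UPDATE_INTERVAL = 500 and 2 global
        # updates per episode per process.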
        # Sanity checks
        assert (
            NUM_MAPPER_STEPS % NUM_LOCAL_STEPS == 0
        ), "Mapper steps must be a multiple of global steps interval"
        assert (
            NUM_LOCAL_STEPS == ans_cfg.goal_interval
        ), "Local steps must be same as subgoal sampling interval"
        with TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            for update in range(num_updates_start, NUM_GLOBAL_UPDATES):
                for step in range(NUM_GLOBAL_STEPS):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                        prev_batch,
                        batch,
                        state_estimates,
                        ground_truth_states,
                    ) = self._collect_rollout_step(
                        batch,
                        prev_batch,
                        episode_step_count,
                        state_estimates,
                        ground_truth_states,
                        masks,
                        mapper_rollouts,
                        local_rollouts,
                        global_rollouts,
                        current_local_episode_reward,
                        current_global_episode_reward,
                        running_episode_stats,
                        statistics_dict,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                    # Useful flags
                    # FROZEN_MAPPER: the mapper itself is not being trained,
                    # i.e. the pose estimator is ignored and either the
                    # anticipator type is one of the frozen types or its
                    # projection unit is frozen.
                    FROZEN_MAPPER = mapper_cfg.ignore_pose_estimator and (
                        occ_cfg.type in self.frozen_mapper_types
                        or mapper_cfg.freeze_projection_unit
                    )
                    # Update the mapper whenever the episode step count reaches
                    # a multiple of NUM_MAPPER_STEPS; the local policy is
                    # updated after every rollout step.
                    UPDATE_MAPPER_FLAG = (
                        episode_step_count[0].item() % NUM_MAPPER_STEPS == 0
                    )
                    UPDATE_LOCAL_FLAG = True

                    # ------------------------ update mapper --------------------------
                    if UPDATE_MAPPER_FLAG:
                        (
                            delta_pth_time,
                            update_metrics_mapper,
                        ) = self._update_mapper_agent(mapper_rollouts)

                        for k, v in update_metrics_mapper.items():
                            statistics_dict["mapper"][k].append(v)

                    pth_time += delta_pth_time

                    # -------------------- update local policy ------------------------
                    if UPDATE_LOCAL_FLAG:
                        delta_pth_time = self._supplementary_rollout_update(
                            batch,
                            prev_batch,
                            episode_step_count,
                            state_estimates,
                            ground_truth_states,
                            masks,
                            local_rollouts,
                            global_rollouts,
                            update_option="local",
                        )

                        # Sanity check
                        assert local_rollouts.step == local_rollouts.num_steps

                        pth_time += delta_pth_time
                        (
                            delta_pth_time,
                            update_metrics_local,
                        ) = self._update_local_agent(local_rollouts)

                        for k, v in update_metrics_local.items():
                            statistics_dict["local_policy"][k].append(v)

                    # -------------------------- log statistics -----------------------
                    for k, v in statistics_dict.items():
                        logger.info(
                            "=========== {:20s} ============".format(k + " stats")
                        )
                        for kp, vp in v.items():
                            if len(vp) > 0:
                                writer.add_scalar(f"{k}/{kp}", np.mean(vp), count_steps)
                                logger.info(f"{kp:25s}: {np.mean(vp).item():10.5f}")

                    for k, v in running_episode_stats.items():
                        window_episode_stats[k].append(v.clone())

                    # Change in each running stat across the reward window,
                    # normalised by the number of episodes completed in it
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "local_reward",
                        deltas["local_reward"] / deltas["count"],
                        count_steps,
                    )
                    writer.add_scalar(
                        "global_reward",
                        deltas["global_reward"] / deltas["count"],
                        count_steps,
                    )
                    fps = (count_steps - count_steps_start) / (time.time() - t_start)
                    writer.add_scalar("fps", fps, count_steps)

                    if update > 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(update, fps))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time, count_steps)
                        )

                        logger.info(
                            "Average window size: {}  {}".format(
                                len(window_episode_stats["count"]),
                                "  ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                    pth_time += delta_pth_time

                # At episode termination, manually set masks to zeros.
                if episode_step_count[0].item() == self.config.T_EXP:
                    masks.fill_(0)

                # -------------------- update global policy -----------------------
                self._supplementary_rollout_update(
                    batch,
                    prev_batch,
                    episode_step_count,
                    state_estimates,
                    ground_truth_states,
                    masks,
                    local_rollouts,
                    global_rollouts,
                    update_option="global",
                )

                # Sanity check
                assert global_rollouts.step == NUM_GLOBAL_STEPS

                (delta_pth_time, update_metrics_global,) = self._update_global_agent(
                    global_rollouts
                )

                for k, v in update_metrics_global.items():
                    statistics_dict["global_policy"][k].append(v)

                pth_time += delta_pth_time

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps, update=update),
                    )
                    count_checkpoints += 1

                # Manually enforce episode termination criterion
                if episode_step_count[0].item() == self.config.T_EXP:

                    # Update episode rewards
                    running_episode_stats["local_reward"] += (
                        1 - masks
                    ) * current_local_episode_reward
                    running_episode_stats["global_reward"] += (
                        1 - masks
                    ) * current_global_episode_reward
                    running_episode_stats["count"] += 1 - masks

                    current_local_episode_reward *= masks
                    current_global_episode_reward *= masks

                    # Measure accumulative error in pose estimates
                    pose_estimation_metrics = measure_pose_estimation_performance(
                        state_estimates["pose_estimates"], ground_truth_states["pose"]
                    )
                    for k, v in pose_estimation_metrics.items():
                        statistics_dict["mapper"]["episode_" + k].append(v)

                    observations = self.envs.reset()
                    batch = self._prepare_batch(observations)
                    prev_batch = batch
                    # Reset episode step counter
                    episode_step_count.fill_(0)
                    # Reset states
                    for k in ground_truth_states.keys():
                        ground_truth_states[k].fill_(0)
                    for k in state_estimates.keys():
                        state_estimates[k].fill_(0)
                    # Update visible occupancy
                    ground_truth_states[
                        "visible_occupancy"
                    ] = self.mapper.ext_register_map(
                        ground_truth_states["visible_occupancy"],
                        rearrange(batch["ego_map_gt"], "b h w c -> b c h w"),
                        batch["pose_gt"],
                    )
                    ground_truth_states["pose"].copy_(batch["pose_gt"])

            self.envs.close()
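
For reference, a minimal sketch of the update-scheduling arithmetic computed inside train(). All values below are hypothetical placeholders chosen only to make the computation concrete; the real values come from the experiment config (config.RL.PPO, config.RL.ANS, T_EXP, NUM_EPISODES, NUM_PROCESSES).

# Hypothetical config values -- not taken from any shipped config file
T_EXP = 1000            # episode length in environment steps
NUM_EPISODES = 10000
NUM_PROCESSES = 4
goal_interval = 25      # ans_cfg.goal_interval
num_local_steps = 25    # ppo_cfg.num_local_steps (must equal goal_interval)
num_global_steps = 20   # ppo_cfg.num_global_steps
num_mapper_steps = 100  # ans_cfg.MAPPER.num_mapper_steps

GLOBAL_UPDATE_INTERVAL = num_global_steps * goal_interval           # 500 steps
NUM_GLOBAL_UPDATES_PER_EPISODE = T_EXP // GLOBAL_UPDATE_INTERVAL    # 2
NUM_GLOBAL_UPDATES = (
    NUM_EPISODES * NUM_GLOBAL_UPDATES_PER_EPISODE // NUM_PROCESSES  # 5000
)

# Same sanity checks as train() performs
assert num_mapper_steps % num_local_steps == 0
assert num_local_steps == goal_interval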