def evaluate_tdn_astar()

in exploring_exploration/utils/eval.py [0:0]


def evaluate_tdn_astar(models, envs, config, device, visualize_policy=False):
    # =============== Evaluation configs ======================
    num_steps_exp = config["num_steps_exp"]
    num_steps_nav = config["num_steps_nav"]
    num_processes = 1
    num_eval_episodes = config["num_eval_episodes"]
    env_name = config["env_name"]
    actor_type = config["actor_type"]
    encoder_type = config["encoder_type"]
    feat_shape_sim = config["feat_shape_sim"]
    use_action_embedding = config["use_action_embedding"]
    use_collision_embedding = config["use_collision_embedding"]
    vis_save_dir = config["vis_save_dir"]
    if actor_type == "forward":
        forward_action_id = config["forward_action_id"]
    elif actor_type == "forward-plus":
        forward_action_id = config["forward_action_id"]
        turn_action_id = config["turn_action_id"]
    elif actor_type == "frontier":
        assert num_processes == 1
        if "avd" in env_name:
            action_space = {"forward": 2, "left": 0, "right": 1, "stop": 3}
        else:
            action_space = {"forward": 0, "left": 1, "right": 2, "stop": 3}
        occ_map_scale = config["occ_map_scale"]
        max_time_per_target = config["max_time_per_target"]
        frontier_agent = FrontierAgent(
            action_space,
            env_name,
            occ_map_scale,
            show_animation=False,
            max_time_per_target=max_time_per_target,
        )

    use_policy = actor_type not in (
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    )

    # =============== Models ======================
    nav_policy = models["nav_policy"]
    if use_policy:
        encoder = models["encoder"]
        actor_critic = models["actor_critic"]

    # Set to evaluation mode
    if use_policy:
        encoder.eval()
        actor_critic.eval()

    tbwriter = TensorboardWriter(log_dir=vis_save_dir)

    # =============== Gather evaluation info  ======================
    episode_environment_statistics = []
    exp_area_covered = []
    exp_collisions = []
    nav_error_all = []
    s_score_all = []
    spl_score_all = []

    def get_obs(obs):
        obs_im = process_image(obs["im"])
        if encoder_type == "rgb+map":
            obs_lm = process_image(obs["coarse_occupancy"])
            obs_sm = process_image(obs["fine_occupancy"])
        else:
            obs_lm = None
            obs_sm = None
        return obs_im, obs_sm, obs_lm

    # =============== Evaluate over predefined number of episodes  ======================
    obs = envs.reset()
    num_eval_batches = (num_eval_episodes // num_processes) + 1
    for neval in range(num_eval_batches):
        # Processing environment inputs
        obs_im, obs_sm, obs_lm = get_obs(obs)
        obs_collns = obs["collisions"]
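        # The frontier baseline needs the egocentric pose change between consecutive steps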
        if actor_type == "frontier":
            delta_ego = torch.zeros((num_processes, 3)).to(device)
            frontier_agent.reset()

        if use_policy:
            recurrent_hidden_states = torch.zeros(num_processes, feat_shape_sim[0]).to(
                device
            )
            masks = torch.zeros(num_processes, 1).to(device)

        nav_policy.reset()

        prev_action = torch.zeros(num_processes, 1).long().to(device)
        prev_collision = obs_collns
        obs_odometer = torch.zeros(num_processes, 4).to(device)
        per_proc_collisions = [0.0 for _ in range(num_processes)]

        # =================================================================
        # ==================== Perform exploration ========================
        # =================================================================
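        # Roll out the exploration phase: pick an action according to actor_type at
        # each step and accumulate coverage / collision statistics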
        for step in range(num_steps_exp):
            if use_policy:
                encoder_inputs = [obs_im]
                if encoder_type == "rgb+map":
                    encoder_inputs += [obs_sm, obs_lm]
                with torch.no_grad():
                    policy_feats = encoder(*encoder_inputs)
                    policy_inputs = {"features": policy_feats}
                    if use_action_embedding:
                        policy_inputs["actions"] = prev_action
                    if use_collision_embedding:
                        policy_inputs["collisions"] = prev_collision.long()

                    policy_outputs = actor_critic.act(
                        policy_inputs,
                        recurrent_hidden_states,
                        masks,
                        deterministic=False,
                    )
                    _, action, _, recurrent_hidden_states = policy_outputs
            elif actor_type == "oracle":
                action = obs["oracle_action"].long()
            elif actor_type == "random":
                action = torch.randint(
                    0, envs.action_space.n, (num_processes, 1)
                ).long()
            elif actor_type == "forward":
                action = torch.Tensor(np.ones((num_processes, 1)) * forward_action_id)
                action = action.long()
            elif actor_type == "forward-plus":
                action = torch.Tensor(np.ones((num_processes, 1)) * forward_action_id)
                collision_mask = prev_collision > 0
                action[collision_mask] = turn_action_id
                action = action.long()
            elif actor_type == "frontier":
                # This assumes that num_processes = 1
                occ_map = obs["highres_coarse_occupancy"][0].cpu().numpy()
                occ_map = occ_map.transpose(1, 2, 0)
                occ_map = np.ascontiguousarray(occ_map)
                occ_map = occ_map.astype(np.uint8)
                action = frontier_agent.act(
                    occ_map, delta_ego[0].cpu().numpy(), prev_collision[0].item()
                )
                action = torch.Tensor([[action]]).long()

            obs, reward, done, infos = envs.step(action)
            # Processing environment inputs
            obs_im, obs_sm, obs_lm = get_obs(obs)
            obs_collns = obs["collisions"]

            obs_odometer_curr = process_odometer(obs["delta"])
            if actor_type == "frontier":
                delta_ego = compute_egocentric_coors(
                    obs_odometer_curr, obs_odometer, occ_map_scale,
                )  # (N, 3) --- (dx_ego, dy_ego, dt_ego)

            # Masks stay 1 here since the episode never terminates during exploration
            masks = torch.FloatTensor([[1.0] for _ in range(num_processes)]).to(device)
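            # Accumulate the per-step odometry deltas to track pose relative to the episode start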
            obs_odometer = obs_odometer + obs_odometer_curr

            # The episode must not terminate during the exploration phase
            assert not done[0]

            # Update collisions metric
            for pr in range(num_processes):
                per_proc_collisions[pr] += obs_collns[pr, 0].item()

            prev_collision = obs_collns
            prev_action = action

            # Verifying correctness
            if step == num_steps_exp - 1:
                assert infos[0]["finished_exploration"]
            elif step < num_steps_exp - 1:
                assert not infos[0]["finished_exploration"]
                exploration_topdown_map = infos[0]["topdown_map"]
        # Update Exploration statistics
        for pr in range(num_processes):
            episode_environment_statistics.append(infos[pr]["environment_statistics"])
            exp_area_covered.append(infos[pr]["seen_area"])
            exp_collisions.append(per_proc_collisions[pr])

        # =================================================================
        # ===================== Navigation evaluation =====================
        # =================================================================
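        # The same episode now continues into the navigation phase: the A*-based
        # nav_policy drives the agent to the target, and nav_error / SPL / success
        # are recorded when the episode ends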
        # gather statistics for visualization
        per_proc_rgb = [[] for pr in range(num_processes)]
        per_proc_depth = [[] for pr in range(num_processes)]
        per_proc_fine_occ = [[] for pr in range(num_processes)]
        per_proc_coarse_occ = [[] for pr in range(num_processes)]
        per_proc_topdown_map = [[] for pr in range(num_processes)]
        per_proc_planner_vis = [[] for pr in range(num_processes)]
        per_proc_gt_topdown_map = [[] for pr in range(num_processes)]
        per_proc_initial_planner_vis = [[] for pr in range(num_processes)]
        per_proc_exploration_topdown_map = [[] for pr in range(num_processes)]

        WIDTH, HEIGHT = 300, 300

        nav_policy.reset()

        initial_planning_vis = None
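        # Navigate until STOP (action 3) is called or the environment signals done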
        for t in range(num_steps_nav):
            # Processing environment inputs
            obs_highres_coarse_occ = torch_to_np(obs["highres_coarse_occupancy"][0])
            if t == 0:
                topdown_map = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
            else:
                topdown_map = infos[0]["topdown_map"]
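            # Target location in occupancy-grid coordinates for the A* planner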
            goal_x = int(obs["target_grid_loc"][0, 0].item())
            goal_y = int(obs["target_grid_loc"][0, 1].item())
            coarse_occ_orig = np.flip(obs_highres_coarse_occ, axis=2)
            coarse_occ_orig = np.ascontiguousarray(coarse_occ_orig)

            action = nav_policy.act(
                coarse_occ_orig, (goal_x, goal_y), obs["collisions"][0, 0].item()
            )
            if action == 3:
                logging.info("=====> STOP action called!")
            actions = torch.Tensor([[action]])

            obs, reward, done, infos = envs.step(actions)

            if visualize_policy:
                if t == 0:
                    initial_planning_vis = np.flip(
                        nav_policy.planning_visualization, axis=2
                    )
                for pr in range(num_processes):
                    per_proc_rgb[pr].append(torch_to_np(obs["im"][pr]))
                    if "habitat" in env_name:
                        per_proc_depth[pr].append(
                            torch_to_np_depth(obs["depth"][pr] * 10000.0)
                        )
                    else:
                        per_proc_depth[pr].append(torch_to_np_depth(obs["depth"][pr]))
                    per_proc_fine_occ[pr].append(torch_to_np(obs["fine_occupancy"][pr]))
                    per_proc_coarse_occ[pr].append(
                        torch_to_np(obs["highres_coarse_occupancy"][pr])
                    )
                    per_proc_topdown_map[pr].append(
                        np.flip(infos[pr]["topdown_map"], axis=2)
                    )
                    per_proc_planner_vis[pr].append(
                        np.flip(nav_policy.planning_visualization, axis=2)
                    )
                    per_proc_initial_planner_vis[pr].append(initial_planning_vis)
                    per_proc_exploration_topdown_map[pr].append(
                        np.flip(exploration_topdown_map, axis=2)
                    )

            if done[0] or action == 3:
                nav_error_all.append(infos[0]["nav_error"])
                spl_score_all.append(infos[0]["spl"])
                s_score_all.append(infos[0]["success_rate"])
                break

            if t == num_steps_nav - 1 and not done[0]:
                raise AssertionError("done was not set at the end of the navigation episode!")

        # Write the episode data to tensorboard
        if visualize_policy:
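            # Tile the per-frame panels side by side and flip the channel order for display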
            proc_fn = lambda x: np.ascontiguousarray(
                np.flip(np.concatenate(x, axis=1), axis=2)
            )
            for pr in range(num_processes):
                rgb_data = per_proc_rgb[pr]
                depth_data = per_proc_depth[pr]
                fine_occ_data = per_proc_fine_occ[pr]
                coarse_occ_data = per_proc_coarse_occ[pr]
                topdown_map_data = per_proc_topdown_map[pr]
                planner_vis_data = per_proc_planner_vis[pr]
                final_topdown_map_data = [
                    topdown_map_data[-1] for _ in range(len(topdown_map_data))
                ]
                initial_planner_vis_data = per_proc_initial_planner_vis[pr]
                exploration_topdown_map_data = per_proc_exploration_topdown_map[pr]

                per_frame_data_proc = zip(
                    rgb_data,
                    coarse_occ_data,
                    topdown_map_data,
                    planner_vis_data,
                    final_topdown_map_data,
                    initial_planner_vis_data,
                    exploration_topdown_map_data,
                )

                video_frames = [
                    proc_fn([cv2.resize(d, (WIDTH, HEIGHT)) for d in per_frame_data])
                    for per_frame_data in per_frame_data_proc
                ]
                tbwriter.add_video_from_np_images(
                    "Episode_{:05d}".format(neval), 0, video_frames, fps=4
                )

        logging.info(
            "===========> Episode done: SPL: {:.3f}, SR: {:.3f}, Nav Err: {:.3f}, Neval: {}".format(
                spl_score_all[-1], s_score_all[-1], nav_error_all[-1], neval
            )
        )

    envs.close()

    # Fill in per-episode statistics
    total_episodes = len(nav_error_all)
    per_episode_statistics = []
    for nep in range(total_episodes):
        per_episode_metrics = {
            "time_step": num_steps_exp,
            "nav_error": nav_error_all[nep],
            "success_rate": s_score_all[nep],
            "spl": spl_score_all[nep],
            "exploration_area_covered": exp_area_covered[nep],
            "exploration_collisions": exp_collisions[nep],
            "environment_statistics": episode_environment_statistics[nep],
        }
        per_episode_statistics.append(per_episode_metrics)

    metrics = {}
    metrics["nav_error"] = np.mean(nav_error_all)
    metrics["spl"] = np.mean(spl_score_all)
    metrics["success_rate"] = np.mean(s_score_all)

    logging.info(
        "======= Evaluated over {} episodes ========".format(
            num_eval_batches * num_processes
        )
    )
    for k, v in metrics.items():
        logging.info("{}: {:.3f}".format(k, v))

    return metrics, per_episode_statistics
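
Example call site (a minimal sketch: the config values below are illustrative, and the
commented-out model / environment constructors stand in for whatever the surrounding
codebase provides; they are not defined in this file):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

config = {
    # Episode budgets (illustrative values)
    "num_steps_exp": 500,
    "num_steps_nav": 200,
    "num_eval_episodes": 100,
    # Agent / environment selection
    "env_name": "habitat",
    "actor_type": "frontier",
    "encoder_type": "rgb",
    "feat_shape_sim": (512,),
    "use_action_embedding": False,
    "use_collision_embedding": False,
    "vis_save_dir": "./tdn_astar_vis",
    # Required only when actor_type == "frontier"
    "occ_map_scale": 0.05,
    "max_time_per_target": 20,
}

# models must provide "nav_policy" (plus "encoder" / "actor_critic" for learned actors),
# and envs must be a vectorized environment with exactly one process.
# models = {"nav_policy": ...}
# envs = ...
# metrics, per_episode_statistics = evaluate_tdn_astar(
#     models, envs, config, device, visualize_policy=False
# )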