in exploring_exploration/utils/eval.py [0:0]
def evaluate_tdn_astar(models, envs, config, device, visualize_policy=False):
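    """Evaluate an exploration actor followed by A*-based navigation.

    Runs an exploration phase of num_steps_exp steps using the configured
    actor (learned policy, oracle, random, forward, forward-plus, or
    frontier), then a navigation phase of up to num_steps_nav steps driven
    by nav_policy, and reports per-episode navigation error, SPL, and
    success rate.
    """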
# =============== Evaluation configs ======================
num_steps_exp = config["num_steps_exp"]
num_steps_nav = config["num_steps_nav"]
num_processes = 1
num_eval_episodes = config["num_eval_episodes"]
env_name = config["env_name"]
actor_type = config["actor_type"]
encoder_type = config["encoder_type"]
feat_shape_sim = config["feat_shape_sim"]
use_action_embedding = config["use_action_embedding"]
use_collision_embedding = config["use_collision_embedding"]
vis_save_dir = config["vis_save_dir"]
if actor_type == "forward":
forward_action_id = config["forward_action_id"]
elif actor_type == "forward-plus":
forward_action_id = config["forward_action_id"]
turn_action_id = config["turn_action_id"]
elif actor_type == "frontier":
assert num_processes == 1
if "avd" in env_name:
action_space = {"forward": 2, "left": 0, "right": 1, "stop": 3}
else:
action_space = {"forward": 0, "left": 1, "right": 2, "stop": 3}
occ_map_scale = config["occ_map_scale"]
max_time_per_target = config["max_time_per_target"]
frontier_agent = FrontierAgent(
action_space,
env_name,
occ_map_scale,
show_animation=False,
max_time_per_target=max_time_per_target,
)
    use_policy = actor_type not in (
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    )
# =============== Models ======================
nav_policy = models["nav_policy"]
if use_policy:
encoder = models["encoder"]
actor_critic = models["actor_critic"]
# Set to evaluation mode
if use_policy:
encoder.eval()
actor_critic.eval()
tbwriter = TensorboardWriter(log_dir=vis_save_dir)
# =============== Gather evaluation info ======================
episode_environment_statistics = []
exp_area_covered = []
exp_collisions = []
nav_error_all = []
s_score_all = []
spl_score_all = []
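    # Converts raw observations into the RGB (and optional coarse/fine
    # occupancy map) tensors expected by the encoder.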
def get_obs(obs):
obs_im = process_image(obs["im"])
if encoder_type == "rgb+map":
obs_lm = process_image(obs["coarse_occupancy"])
obs_sm = process_image(obs["fine_occupancy"])
else:
obs_lm = None
obs_sm = None
return obs_im, obs_sm, obs_lm
# =============== Evaluate over predefined number of episodes ======================
obs = envs.reset()
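    # One extra batch is run so that at least num_eval_episodes episodes are
    # covered (num_processes is fixed to 1 here).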
num_eval_batches = (num_eval_episodes // num_processes) + 1
for neval in range(num_eval_batches):
# Processing environment inputs
obs_im, obs_sm, obs_lm = get_obs(obs)
obs_collns = obs["collisions"]
if actor_type == "frontier":
delta_ego = torch.zeros((num_processes, 3)).to(device)
frontier_agent.reset()
if use_policy:
recurrent_hidden_states = torch.zeros(num_processes, feat_shape_sim[0]).to(
device
)
masks = torch.zeros(num_processes, 1).to(device)
nav_policy.reset()
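        # prev_action / prev_collision feed the optional action and collision
        # embeddings; obs_odometer accumulates the agent's pose change over
        # the exploration phase.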
prev_action = torch.zeros(num_processes, 1).long().to(device)
prev_collision = obs_collns
obs_odometer = torch.zeros(num_processes, 4).to(device)
per_proc_collisions = [0.0 for _ in range(num_processes)]
# =================================================================
# ==================== Perform exploration ========================
# =================================================================
for step in range(num_steps_exp):
if use_policy:
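                # Encode the current observation (plus the egocentric maps for
                # the rgb+map encoder) and sample the next exploration action
                # from the learned policy.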
encoder_inputs = [obs_im]
if encoder_type == "rgb+map":
encoder_inputs += [obs_sm, obs_lm]
with torch.no_grad():
policy_feats = encoder(*encoder_inputs)
policy_inputs = {"features": policy_feats}
if use_action_embedding:
policy_inputs["actions"] = prev_action
if use_collision_embedding:
policy_inputs["collisions"] = prev_collision.long()
policy_outputs = actor_critic.act(
policy_inputs,
recurrent_hidden_states,
masks,
deterministic=False,
)
_, action, _, recurrent_hidden_states = policy_outputs
elif actor_type == "oracle":
action = obs["oracle_action"].long()
elif actor_type == "random":
action = torch.randint(
0, envs.action_space.n, (num_processes, 1)
).long()
elif actor_type == "forward":
action = torch.Tensor(np.ones((num_processes, 1)) * forward_action_id)
action = action.long()
elif actor_type == "forward-plus":
action = torch.Tensor(np.ones((num_processes, 1)) * forward_action_id)
collision_mask = prev_collision > 0
action[collision_mask] = turn_action_id
action = action.long()
elif actor_type == "frontier":
# This assumes that num_processes = 1
occ_map = obs["highres_coarse_occupancy"][0].cpu().numpy()
occ_map = occ_map.transpose(1, 2, 0)
occ_map = np.ascontiguousarray(occ_map)
occ_map = occ_map.astype(np.uint8)
action = frontier_agent.act(
occ_map, delta_ego[0].cpu().numpy(), prev_collision[0].item()
)
action = torch.Tensor([[action]]).long()
obs, reward, done, infos = envs.step(action)
# Processing environment inputs
obs_im, obs_sm, obs_lm = get_obs(obs)
obs_collns = obs["collisions"]
obs_odometer_curr = process_odometer(obs["delta"])
if actor_type == "frontier":
delta_ego = compute_egocentric_coors(
obs_odometer_curr, obs_odometer, occ_map_scale,
) # (N, 3) --- (dx_ego, dy_ego, dt_ego)
            # Masks are kept at 1 since episodes are not expected to terminate
            # during exploration (see the assertion below)
masks = torch.FloatTensor([[1.0] for _ in range(num_processes)]).to(device)
obs_odometer = obs_odometer + obs_odometer_curr
            # The episode must not terminate during the exploration phase
            assert not done[0]
# Update collisions metric
for pr in range(num_processes):
per_proc_collisions[pr] += obs_collns[pr, 0].item()
prev_collision = obs_collns
prev_action = action
            # Sanity check: the environment must flag exploration as finished
            # exactly at the final exploration step
            if step == num_steps_exp - 1:
                assert infos[0]["finished_exploration"]
            else:
                assert not infos[0]["finished_exploration"]
exploration_topdown_map = infos[0]["topdown_map"]
# Update Exploration statistics
for pr in range(num_processes):
episode_environment_statistics.append(infos[pr]["environment_statistics"])
exp_area_covered.append(infos[pr]["seen_area"])
exp_collisions.append(per_proc_collisions[pr])
# =================================================================
# ===================== Navigation evaluation =====================
# =================================================================
# gather statistics for visualization
per_proc_rgb = [[] for pr in range(num_processes)]
per_proc_depth = [[] for pr in range(num_processes)]
per_proc_fine_occ = [[] for pr in range(num_processes)]
per_proc_coarse_occ = [[] for pr in range(num_processes)]
per_proc_topdown_map = [[] for pr in range(num_processes)]
per_proc_planner_vis = [[] for pr in range(num_processes)]
per_proc_gt_topdown_map = [[] for pr in range(num_processes)]
per_proc_initial_planner_vis = [[] for pr in range(num_processes)]
per_proc_exploration_topdown_map = [[] for pr in range(num_processes)]
WIDTH, HEIGHT = 300, 300
nav_policy.reset()
initial_planning_vis = None
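        # Navigation phase: plan toward the target grid location with
        # nav_policy until it outputs STOP or the environment ends the episode.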
for t in range(num_steps_nav):
# Processing environment inputs
obs_highres_coarse_occ = torch_to_np(obs["highres_coarse_occupancy"][0])
if t == 0:
topdown_map = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
else:
topdown_map = infos[0]["topdown_map"]
goal_x = int(obs["target_grid_loc"][0, 0].item())
goal_y = int(obs["target_grid_loc"][0, 1].item())
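            # Reverse the channel order of the occupancy map before handing it
            # to the planner.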
coarse_occ_orig = np.flip(obs_highres_coarse_occ, axis=2)
coarse_occ_orig = np.ascontiguousarray(coarse_occ_orig)
action = nav_policy.act(
coarse_occ_orig, (goal_x, goal_y), obs["collisions"][0, 0].item()
)
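            # Action index 3 corresponds to the STOP action.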
if action == 3:
logging.info("=====> STOP action called!")
actions = torch.Tensor([[action]])
obs, reward, done, infos = envs.step(actions)
if visualize_policy:
if t == 0:
initial_planning_vis = np.flip(
nav_policy.planning_visualization, axis=2
)
for pr in range(num_processes):
per_proc_rgb[pr].append(torch_to_np(obs["im"][pr]))
if "habitat" in env_name:
per_proc_depth[pr].append(
torch_to_np_depth(obs["depth"][pr] * 10000.0)
)
else:
per_proc_depth[pr].append(torch_to_np_depth(obs["depth"][pr]))
per_proc_fine_occ[pr].append(torch_to_np(obs["fine_occupancy"][pr]))
per_proc_coarse_occ[pr].append(
torch_to_np(obs["highres_coarse_occupancy"][pr])
)
per_proc_topdown_map[pr].append(
np.flip(infos[pr]["topdown_map"], axis=2)
)
per_proc_planner_vis[pr].append(
np.flip(nav_policy.planning_visualization, axis=2)
)
per_proc_initial_planner_vis[pr].append(initial_planning_vis)
per_proc_exploration_topdown_map[pr].append(
np.flip(exploration_topdown_map, axis=2)
)
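            # The episode ends when the environment signals done or the policy
            # issues STOP; record the navigation metrics for this episode.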
if done[0] or action == 3:
nav_error_all.append(infos[0]["nav_error"])
spl_score_all.append(infos[0]["spl"])
s_score_all.append(infos[0]["success_rate"])
break
            if t == num_steps_nav - 1 and not done[0]:
                raise AssertionError("Episode did not terminate by the final navigation step!")
# Write the episode data to tensorboard
if visualize_policy:
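            # proc_fn concatenates the per-panel images side by side and flips
            # the channel order for video export.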
proc_fn = lambda x: np.ascontiguousarray(
np.flip(np.concatenate(x, axis=1), axis=2)
)
for pr in range(num_processes):
rgb_data = per_proc_rgb[pr]
depth_data = per_proc_depth[pr]
fine_occ_data = per_proc_fine_occ[pr]
coarse_occ_data = per_proc_coarse_occ[pr]
topdown_map_data = per_proc_topdown_map[pr]
planner_vis_data = per_proc_planner_vis[pr]
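                # Repeat the final top-down map so every video frame also shows
                # the end-of-episode map.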
final_topdown_map_data = [
topdown_map_data[-1] for _ in range(len(topdown_map_data))
]
initial_planner_vis_data = per_proc_initial_planner_vis[pr]
exploration_topdown_map_data = per_proc_exploration_topdown_map[pr]
per_frame_data_proc = zip(
rgb_data,
coarse_occ_data,
topdown_map_data,
planner_vis_data,
final_topdown_map_data,
initial_planner_vis_data,
exploration_topdown_map_data,
)
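                # Resize each panel to a common size and tile the panels into a
                # single frame per time step.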
video_frames = [
proc_fn([cv2.resize(d, (WIDTH, HEIGHT)) for d in per_frame_data])
for per_frame_data in per_frame_data_proc
]
tbwriter.add_video_from_np_images(
"Episode_{:05d}".format(neval), 0, video_frames, fps=4
)
logging.info(
"===========> Episode done: SPL: {:.3f}, SR: {:.3f}, Nav Err: {:.3f}, Neval: {}".format(
spl_score_all[-1], s_score_all[-1], nav_error_all[-1], neval
)
)
envs.close()
# Fill in per-episode statistics
total_episodes = len(nav_error_all)
per_episode_statistics = []
for nep in range(total_episodes):
per_episode_metrics = {
"time_step": num_steps_exp,
"nav_error": nav_error_all[nep],
"success_rate": s_score_all[nep],
"spl": spl_score_all[nep],
"exploration_area_covered": exp_area_covered[nep],
"exploration_collisions": exp_collisions[nep],
"environment_statistics": episode_environment_statistics[nep],
}
per_episode_statistics.append(per_episode_metrics)
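    # Average the navigation metrics over all evaluated episodes.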
metrics = {}
metrics["nav_error"] = np.mean(nav_error_all)
metrics["spl"] = np.mean(spl_score_all)
metrics["success_rate"] = np.mean(s_score_all)
logging.info(
"======= Evaluating for {} episodes ========".format(
num_eval_batches * num_processes
)
)
for k, v in metrics.items():
logging.info("{}: {:.3f}".format(k, v))
return metrics, per_episode_statistics