# occant_baselines/rl/occant_exp_trainer.py
def train(self) -> None:
r"""Main method for training PPO.
Returns:
None
"""
self.envs = construct_envs(
self.config,
get_env_class(self.config.ENV_NAME),
devices=self._assign_devices(),
)
ppo_cfg = self.config.RL.PPO
ans_cfg = self.config.RL.ANS
mapper_cfg = self.config.RL.ANS.MAPPER
occ_cfg = self.config.RL.ANS.OCCUPANCY_ANTICIPATOR
self.device = (
torch.device("cuda", self.config.TORCH_GPU_ID)
if torch.cuda.is_available()
else torch.device("cpu")
)
if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
os.makedirs(self.config.CHECKPOINT_FOLDER)
self.mapper_rollouts = self._create_mapper_rollouts(ppo_cfg, ans_cfg)
self._setup_actor_critic_agent(ppo_cfg, ans_cfg)
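# _setup_actor_critic_agent (not shown here) presumably constructs the three
# trainable components logged below: the mapper, the local policy, and the
# global policy, each wrapped in its own agent.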
logger.info(
"mapper_agent number of parameters: {}".format(
sum(param.numel() for param in self.mapper_agent.parameters())
)
)
logger.info(
"local_agent number of parameters: {}".format(
sum(param.numel() for param in self.local_agent.parameters())
)
)
logger.info(
"global_agent number of parameters: {}".format(
sum(param.numel() for param in self.global_agent.parameters())
)
)
mapper_rollouts = self.mapper_rollouts
global_rollouts = self._create_global_rollouts(ppo_cfg, ans_cfg)
local_rollouts = self._create_local_rollouts(ppo_cfg, ans_cfg)
global_rollouts.to(self.device)
local_rollouts.to(self.device)
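# Rollout storage for each component: the mapper buffer was created above
# (before agent setup), and the local/global policy buffers are moved to the
# training device here.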
# ===================== Create statistics buffers =====================
statistics_dict = {}
# Mapper statistics
statistics_dict["mapper"] = defaultdict(
lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
)
# Local policy statistics
local_episode_rewards = torch.zeros(self.envs.num_envs, 1)
statistics_dict["local_policy"] = defaultdict(
lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
)
window_local_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
window_local_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
# Global policy statistics
global_episode_rewards = torch.zeros(self.envs.num_envs, 1)
statistics_dict["global_policy"] = defaultdict(
lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
)
window_global_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
window_global_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
# Overall count statistics
episode_counts = torch.zeros(self.envs.num_envs, 1)
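# All loss/metric entries are bounded deques, so the values logged below are
# sliding-window averages rather than full-history averages.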
t_start = time.time()
env_time = 0
pth_time = 0
count_steps = 0
count_checkpoints = 0
# ==================== Measuring memory consumption ===================
total_memory_size = 0
print("=================== Mapper rollouts ======================")
for k, v in mapper_rollouts.observations.items():
mem = v.element_size() * v.nelement() * 1e-9
print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
total_memory_size += mem
print(f"Total memory: {total_memory_size:>10.4f} GB")
total_memory_size = 0
print("================== Local policy rollouts =====================")
for k, v in local_rollouts.observations.items():
mem = v.element_size() * v.nelement() * 1e-9
print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
total_memory_size += mem
print(f"Total memory: {total_memory_size:>10.4f} GB")
total_memory_size = 0
print("================== Global policy rollouts ====================")
for k, v in global_rollouts.observations.items():
mem = v.element_size() * v.nelement() * 1e-9
print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
total_memory_size += mem
print(f"Total memory: {total_memory_size:>10.4f} GB")
# Resume checkpoint if available
(
num_updates_start,
count_steps_start,
count_checkpoints,
) = self.resume_checkpoint()
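# resume_checkpoint (not shown here) presumably restores any existing
# checkpoint from CHECKPOINT_FOLDER and returns the update, step, and
# checkpoint counters to resume from (zeros when starting fresh).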
count_steps = count_steps_start
imH, imW = ans_cfg.image_scale_hw
M = ans_cfg.overall_map_size
# ==================== Create state variables =================
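# state_estimates holds the agent's own beliefs, all reset at episode
# boundaries: the estimated pose, the predicted global map (2 x M x M,
# presumably occupied and explored channels), the local policy's recurrent
# hidden state, and a map of visited locations.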
state_estimates = {
# Agent's pose estimate
"pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
# Agent's map
"map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
"recurrent_hidden_states": torch.zeros(
1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
).to(self.device),
"visited_states": torch.zeros(self.envs.num_envs, 1, M, M).to(self.device),
}
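# ground_truth_states mirrors these estimates with privileged simulator
# information, used for mapper supervision and for area-seen / map-accuracy
# rewards; environment_layout is added below only when it is actually needed.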
ground_truth_states = {
# To measure area seen
"visible_occupancy": torch.zeros(
self.envs.num_envs, 2, M, M, device=self.device
),
"pose": torch.zeros(self.envs.num_envs, 3, device=self.device),
"prev_global_reward_metric": torch.zeros(
self.envs.num_envs, 1, device=self.device
),
}
if (
ans_cfg.reward_type == "map_accuracy"
or ans_cfg.LOCAL_POLICY.learning_algorithm == "il"
):
ground_truth_states["environment_layout"] = torch.zeros(
self.envs.num_envs, 2, M, M
).to(self.device)
masks = torch.zeros(self.envs.num_envs, 1)
episode_step_count = torch.zeros(self.envs.num_envs, 1, device=self.device)
# ==================== Reset the environments =================
observations = self.envs.reset()
batch = self._prepare_batch(observations)
prev_batch = batch
# Update visible occupancy
ground_truth_states["visible_occupancy"] = self.mapper.ext_register_map(
ground_truth_states["visible_occupancy"],
rearrange(batch["ego_map_gt"], "b h w c -> b c h w"),
batch["pose_gt"],
)
ground_truth_states["pose"].copy_(batch["pose_gt"])
current_local_episode_reward = torch.zeros(self.envs.num_envs, 1)
current_global_episode_reward = torch.zeros(self.envs.num_envs, 1)
running_episode_stats = dict(
count=torch.zeros(self.envs.num_envs, 1),
local_reward=torch.zeros(self.envs.num_envs, 1),
global_reward=torch.zeros(self.envs.num_envs, 1),
)
window_episode_stats = defaultdict(
lambda: deque(maxlen=ppo_cfg.reward_window_size)
)
# Useful variables
NUM_MAPPER_STEPS = ans_cfg.MAPPER.num_mapper_steps
NUM_LOCAL_STEPS = ppo_cfg.num_local_steps
NUM_GLOBAL_STEPS = ppo_cfg.num_global_steps
GLOBAL_UPDATE_INTERVAL = NUM_GLOBAL_STEPS * ans_cfg.goal_interval
NUM_GLOBAL_UPDATES_PER_EPISODE = self.config.T_EXP // GLOBAL_UPDATE_INTERVAL
NUM_GLOBAL_UPDATES = (
self.config.NUM_EPISODES
* NUM_GLOBAL_UPDATES_PER_EPISODE
// self.config.NUM_PROCESSES
)
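# Step hierarchy: each global step presumably spans goal_interval env steps
# (equal to NUM_LOCAL_STEPS by the asserts below), so one global update
# covers GLOBAL_UPDATE_INTERVAL env steps, an episode of T_EXP env steps
# contains NUM_GLOBAL_UPDATES_PER_EPISODE global updates, and the total
# update count spreads the NUM_EPISODES budget across NUM_PROCESSES
# parallel environments.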
# Sanity checks
assert (
NUM_MAPPER_STEPS % NUM_LOCAL_STEPS == 0
), "Mapper steps must be a multiple of the local steps interval"
assert (
NUM_LOCAL_STEPS == ans_cfg.goal_interval
), "Local steps must equal the subgoal sampling interval"
with TensorboardWriter(
self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
) as writer:
for update in range(num_updates_start, NUM_GLOBAL_UPDATES):
for step in range(NUM_GLOBAL_STEPS):
(
delta_pth_time,
delta_env_time,
delta_steps,
prev_batch,
batch,
state_estimates,
ground_truth_states,
) = self._collect_rollout_step(
batch,
prev_batch,
episode_step_count,
state_estimates,
ground_truth_states,
masks,
mapper_rollouts,
local_rollouts,
global_rollouts,
current_local_episode_reward,
current_global_episode_reward,
running_episode_stats,
statistics_dict,
)
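# _collect_rollout_step (not shown here) presumably steps all environments
# through one global step, filling the mapper/local/global rollout buffers
# and accumulating rewards and episode statistics in place.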
pth_time += delta_pth_time
env_time += delta_env_time
count_steps += delta_steps
# Useful flags
FROZEN_MAPPER = bool(
mapper_cfg.ignore_pose_estimator
and (
occ_cfg.type in self.frozen_mapper_types
or mapper_cfg.freeze_projection_unit
)
)
UPDATE_MAPPER_FLAG = (
episode_step_count[0].item() % NUM_MAPPER_STEPS == 0
)
UPDATE_LOCAL_FLAG = True
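# The mapper is updated only when the episode step count hits a multiple of
# NUM_MAPPER_STEPS, while the local policy is updated after every global step.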
# ------------------------ update mapper --------------------------
if UPDATE_MAPPER_FLAG:
(
delta_pth_time,
update_metrics_mapper,
) = self._update_mapper_agent(mapper_rollouts)
for k, v in update_metrics_mapper.items():
statistics_dict["mapper"][k].append(v)
pth_time += delta_pth_time
# -------------------- update local policy ------------------------
if UPDATE_LOCAL_FLAG:
delta_pth_time = self._supplementary_rollout_update(
batch,
prev_batch,
episode_step_count,
state_estimates,
ground_truth_states,
masks,
local_rollouts,
global_rollouts,
update_option="local",
)
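# _supplementary_rollout_update (not shown here) presumably writes the final
# observations/values needed to bootstrap returns before the PPO update on
# the local rollouts.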
# Sanity check
assert local_rollouts.step == local_rollouts.num_steps
pth_time += delta_pth_time
(
delta_pth_time,
update_metrics_local,
) = self._update_local_agent(local_rollouts)
for k, v in update_metrics_local.items():
statistics_dict["local_policy"][k].append(v)
# -------------------------- log statistics -----------------------
for k, v in statistics_dict.items():
logger.info(
"=========== {:20s} ============".format(k + " stats")
)
for kp, vp in v.items():
if len(vp) > 0:
writer.add_scalar(f"{k}/{kp}", np.mean(vp), count_steps)
logger.info(f"{kp:25s}: {np.mean(vp).item():10.5f}")
for k, v in running_episode_stats.items():
window_episode_stats[k].append(v.clone())
deltas = {
k: (
(v[-1] - v[0]).sum().item()
if len(v) > 1
else v[0].sum().item()
)
for k, v in window_episode_stats.items()
}
deltas["count"] = max(deltas["count"], 1.0)
writer.add_scalar(
"local_reward",
deltas["local_reward"] / deltas["count"],
count_steps,
)
writer.add_scalar(
"global_reward",
deltas["global_reward"] / deltas["count"],
count_steps,
)
fps = (count_steps - count_steps_start) / (time.time() - t_start)
writer.add_scalar("fps", fps, count_steps)
if update > 0:
logger.info("update: {}\tfps: {:.3f}\t".format(update, fps))
logger.info(
"update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
"frames: {}".format(update, env_time, pth_time, count_steps)
)
logger.info(
"Average window size: {} {}".format(
len(window_episode_stats["count"]),
" ".join(
"{}: {:.3f}".format(k, v / deltas["count"])
for k, v in deltas.items()
if k != "count"
),
)
)
pth_time += delta_pth_time
# At episode termination, manually set masks to zeros.
if episode_step_count[0].item() == self.config.T_EXP:
masks.fill_(0)
# -------------------- update global policy -----------------------
self._supplementary_rollout_update(
batch,
prev_batch,
episode_step_count,
state_estimates,
ground_truth_states,
masks,
local_rollouts,
global_rollouts,
update_option="global",
)
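# As with the local policy, presumably record the final step needed to
# bootstrap the global returns before updating the global agent.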
# Sanity check
assert global_rollouts.step == NUM_GLOBAL_STEPS
delta_pth_time, update_metrics_global = self._update_global_agent(
global_rollouts
)
for k, v in update_metrics_global.items():
statistics_dict["global_policy"][k].append(v)
pth_time += delta_pth_time
# checkpoint model
if update % self.config.CHECKPOINT_INTERVAL == 0:
self.save_checkpoint(
f"ckpt.{count_checkpoints}.pth",
dict(step=count_steps, update=update),
)
count_checkpoints += 1
# Manually enforce episode termination criterion
if episode_step_count[0].item() == self.config.T_EXP:
# Update episode rewards
running_episode_stats["local_reward"] += (
1 - masks
) * current_local_episode_reward
running_episode_stats["global_reward"] += (
1 - masks
) * current_global_episode_reward
running_episode_stats["count"] += 1 - masks
current_local_episode_reward *= masks
current_global_episode_reward *= masks
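# masks was zeroed above at episode termination, so (1 - masks) folds the
# finished episode's rewards into the running statistics and multiplying by
# masks clears the per-episode accumulators.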
# Measure the accumulated error in the pose estimates
pose_estimation_metrics = measure_pose_estimation_performance(
state_estimates["pose_estimates"], ground_truth_states["pose"]
)
for k, v in pose_estimation_metrics.items():
statistics_dict["mapper"]["episode_" + k].append(v)
observations = self.envs.reset()
batch = self._prepare_batch(observations)
prev_batch = batch
# Reset episode step counter
episode_step_count.fill_(0)
# Reset states
for k in ground_truth_states.keys():
ground_truth_states[k].fill_(0)
for k in state_estimates.keys():
state_estimates[k].fill_(0)
# Update visible occupancy
ground_truth_states[
"visible_occupancy"
] = self.mapper.ext_register_map(
ground_truth_states["visible_occupancy"],
rearrange(batch["ego_map_gt"], "b h w c -> b c h w"),
batch["pose_gt"],
)
ground_truth_states["pose"].copy_(batch["pose_gt"])
self.envs.close()