in habitat_baselines/rl/ppo/ppo_trainer.py [0:0]
def train(self) -> None:
    r"""Main method for training DD/PPO.

    Returns:
        None
    """
    self._init_train()

    count_checkpoints = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: 1 - self.percent_done(),
    )
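    # Restore from a saved resume state if one exists (e.g. after the job
    # was requeued): model weights, optimizer, LR scheduler, and training
    # progress counters.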
    resume_state = load_resume_state(self.config)
    if resume_state is not None:
        self.agent.load_state_dict(resume_state["state_dict"])
        self.agent.optimizer.load_state_dict(resume_state["optim_state"])
        lr_scheduler.load_state_dict(resume_state["lr_sched_state"])

        requeue_stats = resume_state["requeue_stats"]
        self.env_time = requeue_stats["env_time"]
        self.pth_time = requeue_stats["pth_time"]
        self.num_steps_done = requeue_stats["num_steps_done"]
        self.num_updates_done = requeue_stats["num_updates_done"]
        self._last_checkpoint_percent = requeue_stats[
            "_last_checkpoint_percent"
        ]
        count_checkpoints = requeue_stats["count_checkpoints"]
        prev_time = requeue_stats["prev_time"]

        self.running_episode_stats = requeue_stats["running_episode_stats"]
        self.window_episode_stats.update(
            requeue_stats["window_episode_stats"]
        )
    ppo_cfg = self.config.RL.PPO
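    # Only rank 0 writes TensorBoard logs; all other workers get a no-op
    # context manager in place of the writer.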
    with (
        TensorboardWriter(  # type: ignore
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        )
        if rank0_only()
        else contextlib.suppress()
    ) as writer:
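        # Main training loop: alternate rollout collection and policy
        # updates until is_done() reports that training has finished.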
        while not self.is_done():
            profiling_wrapper.on_start_step()
            profiling_wrapper.range_push("train update")

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * (
                    1 - self.percent_done()
                )
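            # Periodically snapshot the full training state (weights,
            # optimizer, LR schedule, and counters) so an interrupted job
            # can be requeued and resumed.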
            if rank0_only() and self._should_save_resume_state():
                requeue_stats = dict(
                    env_time=self.env_time,
                    pth_time=self.pth_time,
                    count_checkpoints=count_checkpoints,
                    num_steps_done=self.num_steps_done,
                    num_updates_done=self.num_updates_done,
                    _last_checkpoint_percent=self._last_checkpoint_percent,
                    prev_time=(time.time() - self.t_start) + prev_time,
                    running_episode_stats=self.running_episode_stats,
                    window_episode_stats=dict(self.window_episode_stats),
                )

                save_resume_state(
                    dict(
                        state_dict=self.agent.state_dict(),
                        optim_state=self.agent.optimizer.state_dict(),
                        lr_sched_state=lr_scheduler.state_dict(),
                        config=self.config,
                        requeue_stats=requeue_stats,
                    ),
                    self.config,
                )
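            # If an exit signal was received (e.g. the job is being
            # preempted), clean up and requeue instead of continuing.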
            if EXIT.is_set():
                profiling_wrapper.range_pop()  # train update

                self.envs.close()
                requeue_job()
                return
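            # Collect the next rollout; action computation and environment
            # stepping are interleaved across the rollout buffers.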
            self.agent.eval()
            count_steps_delta = 0
            profiling_wrapper.range_push("rollouts loop")

            profiling_wrapper.range_push("_collect_rollout_step")
            for buffer_index in range(self._nbuffers):
                self._compute_actions_and_step_envs(buffer_index)
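            # Step every buffer for up to ppo_cfg.num_steps steps, ending
            # the rollout early when should_end_early() allows it.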
            for step in range(ppo_cfg.num_steps):
                is_last_step = (
                    self.should_end_early(step + 1)
                    or (step + 1) == ppo_cfg.num_steps
                )

                for buffer_index in range(self._nbuffers):
                    count_steps_delta += self._collect_environment_result(
                        buffer_index
                    )

                    if (buffer_index + 1) == self._nbuffers:
                        profiling_wrapper.range_pop()  # _collect_rollout_step

                    if not is_last_step:
                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_push(
                                "_collect_rollout_step"
                            )

                        self._compute_actions_and_step_envs(buffer_index)

                if is_last_step:
                    break
            profiling_wrapper.range_pop()  # rollouts loop

            if self._is_distributed:
                self.num_rollouts_done_store.add("num_done", 1)
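            # Update the policy with PPO on the collected rollout and
            # record the resulting losses.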
            (
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent()

            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()  # type: ignore

            self.num_updates_done += 1
            losses = self._coalesce_post_step(
                dict(
                    value_loss=value_loss,
                    action_loss=action_loss,
                    entropy=dist_entropy,
                ),
                count_steps_delta,
            )
            self._training_log(writer, losses, prev_time)

            # checkpoint model
            if rank0_only() and self.should_checkpoint():
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth",
                    dict(
                        step=self.num_steps_done,
                        wall_time=(time.time() - self.t_start) + prev_time,
                    ),
                )
                count_checkpoints += 1

            profiling_wrapper.range_pop()  # train update

        self.envs.close()