# rl/ppo/ppo_trainer.py
def train(self) -> None:
r"""Main method for training PPO.
Returns:
None
"""
    self.envs = construct_envs(
        self.config, get_env_class(self.config.ENV.ENV_NAME)
    )
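
    # PPO hyperparameters, training device (GPU if available), and checkpoint directory.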
    ppo_cfg = self.config.RL.PPO
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
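
    # Instantiate the actor-critic policy and the PPO agent, then log the parameter count.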
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )

    # [!!] Allow subclasses to create modified rollout storages
    rollouts = self.create_rollout_storage(ppo_cfg)
    rollouts.to(self.device)
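
    # Reset the environments and copy the first observations into step 0 of the rollouts.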
    observations = self.envs.reset()
    batch = self.batch_obs(observations, device=self.device)
    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None
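
    # Per-environment accumulators: reward of the episode in progress, plus running
    # episode statistics and a sliding window over recent episodes.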
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
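
    # Timing and progress counters: env stepping time, PyTorch time, frames, checkpoints.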
    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
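
    # Linear learning-rate decay over NUM_UPDATES; stepped once per update below,
    # but only when use_linear_lr_decay is enabled.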
    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )

    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )
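
            # Collect ppo_cfg.num_steps environment steps into the rollout buffer.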
            for step in tqdm.tqdm(range(ppo_cfg.num_steps)):  # [!!] Add tqdm
                (
                    delta_pth_time,
                    delta_env_time,
                    delta_steps,
                ) = self._collect_rollout_step(
                    rollouts, current_episode_reward, running_episode_stats
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps
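
            # One PPO optimization phase over the collected rollout.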
            (
                delta_pth_time,
                value_loss,
                action_loss,
                dist_entropy,
            ) = self._update_agent(ppo_cfg, rollouts)
            pth_time += delta_pth_time
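
            # Snapshot the running episode stats into the sliding reward window.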
            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())
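
            # Change in each stat across the window (newest snapshot minus oldest).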
            deltas = {
                k: (
                    (v[-1] - v[0]).sum().item()
                    if len(v) > 1
                    else v[0].sum().item()
                )
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)
            writer.add_scalar(
                "reward", deltas["reward"] / deltas["count"], count_steps
            )

            # [!!] Write policy/value/dist_entropy losses directly
            writer.add_scalar("policy_loss", action_loss, count_steps)
            writer.add_scalar("value_loss", value_loss, count_steps)
            writer.add_scalar("dist_entropy", dist_entropy, count_steps)

            # # Check to see if there are any metrics
            # # that haven't been logged yet
            # metrics = {
            #     k: v / deltas["count"]
            #     for k, v in deltas.items()
            #     if k not in {"reward", "count"}
            # }
            # if len(metrics) > 0:
            #     writer.add_scalars("metrics", metrics, count_steps)

            # losses = [value_loss, action_loss]
            # writer.add_scalars(
            #     "losses",
            #     {k: l for l, k in zip(losses, ["value", "policy"])},
            #     count_steps,
            # )

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info(
                    "update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)
                    )
                )
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(
                        update, env_time, pth_time, count_steps
                    )
                )
                logger.info(
                    "Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items()
                            if k != "count"
                        ),
                    )
                )

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(
                    f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)
                )
                count_checkpoints += 1
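
        # All updates finished; shut down the vectorized environments.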
        self.envs.close()