in ss_baselines/savi/ppo/ppo_trainer.py
def train(self) -> None:
    r"""Main method for training PPO.

    Returns:
        None
    """
    logger.info(f"config: {self.config}")

    random.seed(self.config.SEED)
    np.random.seed(self.config.SEED)
    torch.manual_seed(self.config.SEED)
    # add_signal_handlers()
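
    # Construct the vectorized environments and select the training device.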
    self.envs = construct_envs(
        self.config, get_env_class(self.config.ENV_NAME), workers_ignore_signals=True
    )

    ppo_cfg = self.config.RL.PPO
    self.device = (
        torch.device("cuda", self.config.TORCH_GPU_ID)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
    self._setup_actor_critic_agent(ppo_cfg)
    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )
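
    # When the scene memory transformer's external memory is enabled, the rollout
    # storage also holds that memory (memory_size slots plus the num_steps
    # collected in each rollout).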
    if ppo_cfg.use_external_memory:
        memory_dim = self.actor_critic.net.memory_dim
    else:
        memory_dim = None

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        self.envs.observation_spaces[0],
        self.envs.action_spaces[0],
        ppo_cfg.hidden_size,
        ppo_cfg.use_external_memory,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
        ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
        memory_dim,
    )
    rollouts.to(self.device)
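
    # Reset the environments and copy the first observations into the rollout buffer.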
    observations = self.envs.reset()
    batch = batch_obs(observations)
    if self.config.RL.PPO.use_belief_predictor:
        self.belief_predictor.update(batch, None)

    for sensor in rollouts.observations:
        rollouts.observations[sensor][0].copy_(batch[sensor])

    # batch and observations may contain shared PyTorch CUDA
    # tensors. We must explicitly clear them here otherwise
    # they will be kept in memory for the entire duration of training!
    batch = None
    observations = None
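
    # Per-environment episode statistics plus a sliding window over recent episodes,
    # used for logging averaged rewards and metrics.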
    current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )

    t_start = time.time()
    env_time = 0
    pth_time = 0
    count_steps = 0
    count_checkpoints = 0
    start_update = 0
    prev_time = 0

    lr_scheduler = LambdaLR(
        optimizer=self.agent.optimizer,
        lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
    )
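
    # Resume from a previously interrupted (requeued) run, if a saved state exists.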
    interrupted_state = load_interrupted_state(model_dir=self.config.MODEL_DIR)
    if interrupted_state is not None:
        self.agent.load_state_dict(interrupted_state["state_dict"])
        self.agent.optimizer.load_state_dict(
            interrupted_state["optimizer_state"]
        )
        lr_scheduler.load_state_dict(interrupted_state["lr_scheduler_state"])

        requeue_stats = interrupted_state["requeue_stats"]
        env_time = requeue_stats["env_time"]
        pth_time = requeue_stats["pth_time"]
        count_steps = requeue_stats["count_steps"]
        count_checkpoints = requeue_stats["count_checkpoints"]
        start_update = requeue_stats["start_update"]
        prev_time = requeue_stats["prev_time"]
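
    # Main training loop: each update collects a rollout and performs one PPO update.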
    with TensorboardWriter(
        self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        for update in range(start_update, self.config.NUM_UPDATES):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES
                )
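
            # Handle job pre-emption: close the envs, save a resumable snapshot if
            # requeueing was requested, then exit this process.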
            if EXIT.is_set():
                self.envs.close()

                if REQUEUE.is_set():
                    requeue_stats = dict(
                        env_time=env_time,
                        pth_time=pth_time,
                        count_steps=count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=update,
                        prev_time=(time.time() - t_start) + prev_time,
                    )
                    save_interrupted_state(
                        dict(
                            state_dict=self.agent.state_dict(),
                            optimizer_state=self.agent.optimizer.state_dict(),
                            lr_scheduler_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ),
                        model_dir=self.config.MODEL_DIR
                    )

                requeue_job()
                return
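
            # Collect ppo_cfg.num_steps transitions per environment, then update the agent.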
            for step in range(ppo_cfg.num_steps):
                delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                    rollouts,
                    current_episode_reward,
                    running_episode_stats
                )
                pth_time += delta_pth_time
                env_time += delta_env_time
                count_steps += delta_steps

            delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                ppo_cfg, rollouts
            )
            pth_time += delta_pth_time
            # Accumulate the running per-env stats into the sliding window before
            # computing per-window deltas for logging.
            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: (
                    (v[-1] - v[0]).sum().item()
                    if len(v) > 1
                    else v[0].sum().item()
                )
                for k, v in window_episode_stats.items()
            }
            deltas["count"] = max(deltas["count"], 1.0)

            writer.add_scalar(
                "Metrics/reward", deltas["reward"] / deltas["count"], count_steps
            )

            # Check to see if there are any metrics
            # that haven't been logged yet
            metrics = {
                k: v / deltas["count"]
                for k, v in deltas.items()
                if k not in {"reward", "count"}
            }
            if len(metrics) > 0:
                # writer.add_scalars("metrics", metrics, count_steps)
                for metric, value in metrics.items():
                    writer.add_scalar(f"Metrics/{metric}", value, count_steps)

            writer.add_scalar("Policy/value_loss", value_loss, count_steps)
            writer.add_scalar("Policy/policy_loss", action_loss, count_steps)
            writer.add_scalar("Policy/entropy_loss", dist_entropy, count_steps)
            writer.add_scalar("Policy/learning_rate", lr_scheduler.get_lr()[0], count_steps)
            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info(
                    "update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)
                    )
                )
                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(
                        update, env_time, pth_time, count_steps
                    )
                )
                logger.info(
                    "Average window size: {} {}".format(
                        len(window_episode_stats["count"]),
                        " ".join(
                            "{}: {:.3f}".format(k, v / deltas["count"])
                            for k, v in deltas.items()
                            if k != "count"
                        ),
                    )
                )

            # checkpoint model
            if update % self.config.CHECKPOINT_INTERVAL == 0:
                self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                count_checkpoints += 1

        self.envs.close()