in ss_baselines/savi/ddppo/algo/ddppo_trainer.py
def train(self) -> None:
r"""Main method for DD-PPO.
Returns:
None
"""
self.local_rank, tcp_store = init_distrib_slurm(
self.config.RL.DDPPO.distrib_backend
)
add_signal_handlers()
# Stores the number of workers that have finished their rollout
num_rollouts_done_store = distrib.PrefixStore(
"rollout_tracker", tcp_store
)
num_rollouts_done_store.set("num_done", "0")
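# Global rank and world size over all distributed workers; rank 0 handles logging and checkpointing.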
self.world_rank = distrib.get_rank()
self.world_size = distrib.get_world_size()
self.config.defrost()
self.config.TORCH_GPU_ID = self.local_rank
self.config.SIMULATOR_GPU_ID = self.local_rank
# Offset the seed by world_rank * NUM_PROCESSES so every simulator in every worker gets a unique seed
self.config.TASK_CONFIG.SEED += (
self.world_rank * self.config.NUM_PROCESSES
)
self.config.freeze()
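# Seed Python, NumPy and PyTorch with the per-rank seed computed above.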
random.seed(self.config.TASK_CONFIG.SEED)
np.random.seed(self.config.TASK_CONFIG.SEED)
torch.manual_seed(self.config.TASK_CONFIG.SEED)
if torch.cuda.is_available():
self.device = torch.device("cuda", self.local_rank)
torch.cuda.set_device(self.device)
else:
self.device = torch.device("cpu")
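# Each worker constructs its own batch of vectorized simulator environments on its assigned GPU.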
self.envs = construct_envs(
self.config, get_env_class(self.config.ENV_NAME)
)
ppo_cfg = self.config.RL.PPO
if (
not os.path.isdir(self.config.CHECKPOINT_FOLDER)
and self.world_rank == 0
):
os.makedirs(self.config.CHECKPOINT_FOLDER)
self._setup_actor_critic_agent(ppo_cfg)
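# Wrap the agent for distributed gradient averaging; find_unused_params=True allows
# parameters that receive no gradient in a given update (DDP's find_unused_parameters).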
self.agent.init_distributed(find_unused_params=True)
if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
self.belief_predictor.init_distributed(find_unused_params=True)
if self.world_rank == 0:
logger.info(
"agent number of trainable parameters: {}".format(
sum(
param.numel()
for param in self.agent.parameters()
if param.requires_grad
)
)
)
if ppo_cfg.use_belief_predictor:
logger.info(
"belief predictor number of trainable parameters: {}".format(
sum(
param.numel()
for param in self.belief_predictor.parameters()
if param.requires_grad
)
)
)
logger.info(f"config: {self.config}")
observations = self.envs.reset()
batch = batch_obs(observations, device=self.device)
obs_space = self.envs.observation_spaces[0]
if ppo_cfg.use_external_memory:
memory_dim = self.actor_critic.net.memory_dim
else:
memory_dim = None
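# Rollout buffer; with external memory (scene memory transformer), the memory buffer
# is sized to hold memory_size stored entries plus the num_steps of the current rollout.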
rollouts = RolloutStorage(
ppo_cfg.num_steps,
self.envs.num_envs,
obs_space,
self.action_space,
ppo_cfg.hidden_size,
ppo_cfg.use_external_memory,
ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
memory_dim,
num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
)
rollouts.to(self.device)
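# If a belief predictor is used, let it process the initial batch before the first
# observations are copied into rollout storage.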
if self.config.RL.PPO.use_belief_predictor:
self.belief_predictor.update(batch, None)
for sensor in rollouts.observations:
rollouts.observations[sensor][0].copy_(batch[sensor])
# batch and observations may contain shared PyTorch CUDA tensors.
# Clear them explicitly here, otherwise they are kept in memory
# for the entire duration of training!
batch = None
observations = None
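# Per-environment reward accumulator and running episode stats; window_episode_stats
# keeps a sliding window (reward_window_size) of reduced stats for logging.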
current_episode_reward = torch.zeros(
self.envs.num_envs, 1, device=self.device
)
running_episode_stats = dict(
count=torch.zeros(self.envs.num_envs, 1, device=self.device),
reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
)
window_episode_stats = defaultdict(
lambda: deque(maxlen=ppo_cfg.reward_window_size)
)
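# Throughput/timing counters; prev_time carries over wall-clock time from earlier (requeued) runs.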
t_start = time.time()
env_time = 0
pth_time = 0
count_steps = 0
count_checkpoints = 0
start_update = 0
prev_time = 0
lr_scheduler = LambdaLR(
optimizer=self.agent.optimizer,
lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
)
# Try to resume from the most recent saved checkpoint, if any (independent of interrupted/requeued state)
count_steps_start, count_checkpoints, start_update = self.try_to_resume_checkpoint()
count_steps = count_steps_start
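# If a previous run was interrupted (e.g. SLURM preemption), restore agent, optimizer,
# scheduler and bookkeeping counters from the saved snapshot.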
interrupted_state = load_interrupted_state()
if interrupted_state is not None:
self.agent.load_state_dict(interrupted_state["state_dict"])
if self.config.RL.PPO.use_belief_predictor:
self.belief_predictor.load_state_dict(interrupted_state["belief_predictor"])
self.agent.optimizer.load_state_dict(
interrupted_state["optim_state"]
)
lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])
requeue_stats = interrupted_state["requeue_stats"]
env_time = requeue_stats["env_time"]
pth_time = requeue_stats["pth_time"]
count_steps = requeue_stats["count_steps"]
count_checkpoints = requeue_stats["count_checkpoints"]
start_update = requeue_stats["start_update"]
prev_time = requeue_stats["prev_time"]
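# Only rank 0 writes TensorBoard summaries; every other rank gets a no-op context manager.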
with (
TensorboardWriter(
self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
)
if self.world_rank == 0
else contextlib.suppress()
) as writer:
for update in range(start_update, self.config.NUM_UPDATES):
if ppo_cfg.use_linear_lr_decay:
lr_scheduler.step()
if ppo_cfg.use_linear_clip_decay:
self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
update, self.config.NUM_UPDATES
)
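# Preemption handling: on an exit signal, close the environments; on requeue, rank 0
# saves an interrupted-state snapshot so training can resume, and the job is requeued.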
if EXIT.is_set():
self.envs.close()
if REQUEUE.is_set() and self.world_rank == 0:
requeue_stats = dict(
env_time=env_time,
pth_time=pth_time,
count_steps=count_steps,
count_checkpoints=count_checkpoints,
start_update=update,
prev_time=(time.time() - t_start) + prev_time,
)
state_dict = dict(
state_dict=self.agent.state_dict(),
optim_state=self.agent.optimizer.state_dict(),
lr_sched_state=lr_scheduler.state_dict(),
config=self.config,
requeue_stats=requeue_stats,
)
if self.config.RL.PPO.use_belief_predictor:
state_dict['belief_predictor'] = self.belief_predictor.state_dict()
save_interrupted_state(state_dict)
requeue_job()
return
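# Experience collection: run the policy in eval mode and step the environments
# for up to num_steps per update.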
count_steps_delta = 0
self.agent.eval()
if self.config.RL.PPO.use_belief_predictor:
self.belief_predictor.eval()
for step in range(ppo_cfg.num_steps):
(
delta_pth_time,
delta_env_time,
delta_steps,
) = self._collect_rollout_step(
rollouts, current_episode_reward, running_episode_stats
)
pth_time += delta_pth_time
env_time += delta_env_time
count_steps_delta += delta_steps
# This is where the preemption of workers happens. If a
# worker detects it will be a straggler, it preempts itself!
if (
step
>= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
) and int(num_rollouts_done_store.get("num_done")) > (
self.config.RL.DDPPO.sync_frac * self.world_size
):
break
num_rollouts_done_store.add("num_done", 1)
self.agent.train()
if self.config.RL.PPO.use_belief_predictor:
self.belief_predictor.train()
self.belief_predictor.set_eval_encoders()
if self._static_smt_encoder:
self.actor_critic.net.set_eval_encoders()
if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
location_predictor_loss, prediction_accuracy = self.train_belief_predictor(rollouts)
else:
location_predictor_loss = 0
prediction_accuracy = 0
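# PPO update on the collected rollout (returns the value loss, policy loss and distribution entropy).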
(
delta_pth_time,
value_loss,
action_loss,
dist_entropy,
) = self._update_agent(ppo_cfg, rollouts)
pth_time += delta_pth_time
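# All-reduce episode stats and training losses across workers; losses are divided
# by world_size below to report per-worker means.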
stats_ordering = list(sorted(running_episode_stats.keys()))
stats = torch.stack(
[running_episode_stats[k] for k in stats_ordering], 0
)
distrib.all_reduce(stats)
for i, k in enumerate(stats_ordering):
window_episode_stats[k].append(stats[i].clone())
stats = torch.tensor(
[value_loss, action_loss, dist_entropy, location_predictor_loss, prediction_accuracy, count_steps_delta],
device=self.device,
)
distrib.all_reduce(stats)
count_steps += stats[5].item()
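# Rank 0 only: reset the straggler counter and handle TensorBoard logging,
# console logging and checkpointing.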
if self.world_rank == 0:
num_rollouts_done_store.set("num_done", "0")
losses = [
stats[0].item() / self.world_size,
stats[1].item() / self.world_size,
stats[2].item() / self.world_size,
stats[3].item() / self.world_size,
stats[4].item() / self.world_size,
]
deltas = {
k: (
(v[-1] - v[0]).sum().item()
if len(v) > 1
else v[0].sum().item()
)
for k, v in window_episode_stats.items()
}
deltas["count"] = max(deltas["count"], 1.0)
writer.add_scalar(
"Metrics/reward", deltas["reward"] / deltas["count"], count_steps
)
# Check to see if there are any metrics
# that haven't been logged yet
metrics = {
k: v / deltas["count"]
for k, v in deltas.items()
if k not in {"reward", "count"}
}
if len(metrics) > 0:
for metric, value in metrics.items():
writer.add_scalar(f"Metrics/{metric}", value, count_steps)
writer.add_scalar("Policy/value_loss", losses[0], count_steps)
writer.add_scalar("Policy/policy_loss", losses[1], count_steps)
writer.add_scalar("Policy/entropy_loss", losses[2], count_steps)
writer.add_scalar("Policy/predictor_loss", losses[3], count_steps)
writer.add_scalar("Policy/predictor_accuracy", losses[4], count_steps)
writer.add_scalar('Policy/learning_rate', lr_scheduler.get_lr()[0], count_steps)
# log stats
if update > 0 and update % self.config.LOG_INTERVAL == 0:
logger.info(
"update: {}\tfps: {:.3f}\t".format(
update,
(count_steps - count_steps_start)
/ ((time.time() - t_start) + prev_time),
)
)
logger.info(
"update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
"frames: {}".format(
update, env_time, pth_time, count_steps
)
)
logger.info(
"Average window size: {} {}".format(
len(window_episode_stats["count"]),
" ".join(
"{}: {:.3f}".format(k, v / deltas["count"])
for k, v in deltas.items()
if k != "count"
),
)
)
# checkpoint model
if update % self.config.CHECKPOINT_INTERVAL == 0:
self.save_checkpoint(
f"ckpt.{count_checkpoints}.pth",
dict(step=count_steps),
)
count_checkpoints += 1
self.envs.close()