in habitat_baselines/rl/ppo/ppo_trainer.py [0:0]
def _init_train(self):
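    # Optionally resume from a previously saved training state; when present,
    # the resumed config fully replaces the current one.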
    resume_state = load_resume_state(self.config)
    if resume_state is not None:
        self.config: Config = resume_state["config"]
        self.using_velocity_ctrl = (
            self.config.TASK_CONFIG.TASK.POSSIBLE_ACTIONS
        ) == ["VELOCITY_CONTROL"]
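    # Force distributed training when requested, and install SLURM preemption
    # signal handlers when running as a batch job.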
    if self.config.RL.DDPPO.force_distributed:
        self._is_distributed = True

    if is_slurm_batch_job():
        add_signal_handlers()
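    # Distributed (DD-PPO) setup: initialize the process group, assign this
    # rank its own GPU and a unique seed, and create a TCP store used to track
    # how many workers have finished their rollouts.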
    if self._is_distributed:
        local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        if rank0_only():
            logger.info(
                "Initialized DD-PPO with {} workers".format(
                    torch.distributed.get_world_size()
                )
            )

        self.config.defrost()
        self.config.TORCH_GPU_ID = local_rank
        self.config.SIMULATOR_GPU_ID = local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            torch.distributed.get_rank() * self.config.NUM_ENVIRONMENTS
        )
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)
        self.num_rollouts_done_store = torch.distributed.PrefixStore(
            "rollout_tracker", tcp_store
        )
        self.num_rollouts_done_store.set("num_done", "0")
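    # Log the resolved config once (rank 0 only) and configure the profiler
    # capture window.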
    if rank0_only() and self.config.VERBOSE:
        logger.info(f"config: {self.config}")

    profiling_wrapper.configure(
        capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
        num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
    )
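    # Construct the vectorized environments.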
    self._init_envs()
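    # Choose the policy action space: a 2-D continuous space for velocity
    # control, otherwise the task's discrete action space.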
    if self.using_velocity_ctrl:
        self.policy_action_space = self.envs.action_spaces[0][
            "VELOCITY_CONTROL"
        ]
        action_shape = (2,)
        discrete_actions = False
    else:
        self.policy_action_space = self.envs.action_spaces[0]
        action_shape = None
        discrete_actions = True
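    # Train on the configured GPU when CUDA is available, otherwise on CPU.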
    ppo_cfg = self.config.RL.PPO
    if torch.cuda.is_available():
        self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
        torch.cuda.set_device(self.device)
    else:
        self.device = torch.device("cpu")
    if rank0_only() and not os.path.isdir(self.config.CHECKPOINT_FOLDER):
        os.makedirs(self.config.CHECKPOINT_FOLDER)
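    # Build the actor-critic policy and the PPO agent, wrap it for distributed
    # training if needed, and log the total parameter count.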
    self._setup_actor_critic_agent(ppo_cfg)
    if self._is_distributed:
        self.agent.init_distributed(find_unused_params=True)  # type: ignore

    logger.info(
        "agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())
        )
    )
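    # With a frozen visual encoder, precomputed visual features are stored in
    # the rollout buffer, so they are added to the observation space here.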
    obs_space = self.obs_space
    if self._static_encoder:
        self._encoder = self.actor_critic.net.visual_encoder
        obs_space = spaces.Dict(
            {
                "visual_features": spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            }
        )
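    # Allocate the rollout storage (double-buffered if configured) and move it
    # to the training device.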
    self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1

    self.rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        self.envs.num_envs,
        obs_space,
        self.policy_action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        is_double_buffered=ppo_cfg.use_double_buffered_sampler,
        action_shape=action_shape,
        discrete_actions=discrete_actions,
    )
    self.rollouts.to(self.device)
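    # Reset the environments and write the first batch of (transformed)
    # observations into step 0 of the rollout buffer.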
    observations = self.envs.reset()
    batch = batch_obs(
        observations, device=self.device, cache=self._obs_batching_cache
    )
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)  # type: ignore

    if self._static_encoder:
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)

    self.rollouts.buffers["observations"][0] = batch  # type: ignore
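    # Per-environment reward accumulators and windowed episode statistics used
    # for logging.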
    self.current_episode_reward = torch.zeros(self.envs.num_envs, 1)
    self.running_episode_stats = dict(
        count=torch.zeros(self.envs.num_envs, 1),
        reward=torch.zeros(self.envs.num_envs, 1),
    )
    self.window_episode_stats = defaultdict(
        lambda: deque(maxlen=ppo_cfg.reward_window_size)
    )
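    # Timers: environment stepping time, PyTorch (policy update) time, and the
    # training start time.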
    self.env_time = 0.0
    self.pth_time = 0.0
    self.t_start = time.time()