in habitat_baselines/rl/ppo/ppo_trainer.py
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
    r"""Sets up actor critic and agent for PPO.

    Args:
        ppo_cfg: config node with relevant params

    Returns:
        None
    """
    logger.add_filehandler(self.config.LOG_FILE)

    policy = baseline_registry.get_policy(self.config.RL.POLICY.name)
    observation_space = self.obs_space
    self.obs_transforms = get_active_obs_transforms(self.config)
    observation_space = apply_obs_transforms_obs_space(
        observation_space, self.obs_transforms
    )
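    # Build the actor-critic policy from config with the transformed
    # observation space, then move it to the training device.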
    self.actor_critic = policy.from_config(
        self.config, observation_space, self.policy_action_space
    )
    self.obs_space = observation_space
    self.actor_critic.to(self.device)
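
    # Optionally restore pretrained DD-PPO weights: either the full policy
    # or only its visual encoder, depending on the config flags.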
    if (
        self.config.RL.DDPPO.pretrained_encoder
        or self.config.RL.DDPPO.pretrained
    ):
        pretrained_state = torch.load(
            self.config.RL.DDPPO.pretrained_weights, map_location="cpu"
        )
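
    # Full checkpoint: load the complete actor-critic state dict, stripping
    # the leading "actor_critic." prefix from the checkpoint keys.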
    if self.config.RL.DDPPO.pretrained:
        self.actor_critic.load_state_dict(
            {  # type: ignore
                k[len("actor_critic.") :]: v
                for k, v in pretrained_state["state_dict"].items()
            }
        )
    elif self.config.RL.DDPPO.pretrained_encoder:
        prefix = "actor_critic.net.visual_encoder."
        self.actor_critic.net.visual_encoder.load_state_dict(
            {
                k[len(prefix) :]: v
                for k, v in pretrained_state["state_dict"].items()
                if k.startswith(prefix)
            }
        )
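
    # Optionally freeze the visual encoder so its parameters are not updated
    # during training.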
    if not self.config.RL.DDPPO.train_encoder:
        self._static_encoder = True
        for param in self.actor_critic.net.visual_encoder.parameters():
            param.requires_grad_(False)
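
    # Optionally re-initialize the critic head with orthogonal weights and
    # a zero bias.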
    if self.config.RL.DDPPO.reset_critic:
        nn.init.orthogonal_(self.actor_critic.critic.fc.weight)
        nn.init.constant_(self.actor_critic.critic.fc.bias, 0)
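
    # Finally, wrap the policy in the PPO updater (DD-PPO when running
    # distributed), passing the optimization hyperparameters from ppo_cfg.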
    self.agent = (DDPPO if self._is_distributed else PPO)(
        actor_critic=self.actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
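
# For context, a minimal sketch (not part of this file) of how a trainer might
# call this helper during training setup; the surrounding calls such as
# `self._init_envs()` are assumptions about the enclosing trainer, not code
# taken from this method:
#
#     ppo_cfg = self.config.RL.PPO
#     self._init_envs()
#     self._setup_actor_critic_agent(ppo_cfg)
#     logger.info(
#         "agent number of parameters: {}".format(
#             sum(param.numel() for param in self.agent.parameters())
#         )
#     )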