in occant_baselines/rl/occant_exp_trainer.py
def _setup_actor_critic_agent(self, ppo_cfg: Config, ans_cfg: Config) -> None:
    r"""Sets up the actor-critic networks and agents for PPO.

    Args:
        ppo_cfg: config node with relevant params
        ans_cfg: config node for the ActiveNeuralSLAM model

    Returns:
        None
    """
    # Create the tensorboard / logging directory if it does not exist yet.
    os.makedirs(self.config.TENSORBOARD_DIR, exist_ok=True)
    logger.add_filehandler(os.path.join(self.config.TENSORBOARD_DIR, "run.log"))
    occ_cfg = ans_cfg.OCCUPANCY_ANTICIPATOR
    mapper_cfg = ans_cfg.MAPPER
    # Create occupancy anticipation model
    occupancy_model = OccupancyAnticipator(occ_cfg)
    occupancy_model = OccupancyAnticipationWrapper(
        occupancy_model, mapper_cfg.map_size, (128, 128)
    )
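    # (The wrapper adapts the anticipator's outputs to the mapper's map_size;
    # (128, 128) is presumably the model's own working resolution.)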
    # Create ANS model
    self.ans_net = ActiveNeuralSLAMExplorer(ans_cfg, occupancy_model)
    self.mapper = self.ans_net.mapper
    self.local_actor_critic = self.ans_net.local_policy
    self.global_actor_critic = self.ans_net.global_policy
    # Create depth projection model to estimate visible occupancy
    self.depth_projection_net = DepthProjectionNet(
        ans_cfg.OCCUPANCY_ANTICIPATOR.EGO_PROJECTION
    )
    # Set to device
    self.mapper.to(self.device)
    self.local_actor_critic.to(self.device)
    self.global_actor_critic.to(self.device)
    self.depth_projection_net.to(self.device)
    # ============================== Create agents ================================
    # Mapper agent
    self.mapper_agent = MapUpdate(
        self.mapper,
        lr=mapper_cfg.lr,
        eps=mapper_cfg.eps,
        label_id=mapper_cfg.label_id,
        max_grad_norm=mapper_cfg.max_grad_norm,
        pose_loss_coef=mapper_cfg.pose_loss_coef,
        occupancy_anticipator_type=ans_cfg.OCCUPANCY_ANTICIPATOR.type,
        freeze_projection_unit=mapper_cfg.freeze_projection_unit,
        num_update_batches=mapper_cfg.num_update_batches,
        batch_size=mapper_cfg.map_batch_size,
        mapper_rollouts=self.mapper_rollouts,
    )
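    # (MapUpdate owns the mapper's optimizer and performs the supervised
    # map/pose-loss updates using samples drawn from mapper_rollouts.)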
    # Local policy: a hand-crafted heuristic (no learned agent), a PPO agent,
    # or an imitation-learning agent, depending on the config.
    if ans_cfg.LOCAL_POLICY.use_heuristic_policy:
        self.local_agent = None
    elif ans_cfg.LOCAL_POLICY.learning_algorithm == "rl":
        self.local_agent = PPO(
            actor_critic=self.local_actor_critic,
            clip_param=ppo_cfg.clip_param,
            ppo_epoch=ppo_cfg.ppo_epoch,
            num_mini_batch=ppo_cfg.num_mini_batch,
            value_loss_coef=ppo_cfg.value_loss_coef,
            entropy_coef=ppo_cfg.local_entropy_coef,
            lr=ppo_cfg.local_policy_lr,
            eps=ppo_cfg.eps,
            max_grad_norm=ppo_cfg.max_grad_norm,
        )
    else:
        self.local_agent = Imitation(
            actor_critic=self.local_actor_critic,
            lr=ppo_cfg.local_policy_lr,
            eps=ppo_cfg.eps,
            max_grad_norm=ppo_cfg.max_grad_norm,
        )
    # Global policy
    self.global_agent = PPO(
        actor_critic=self.global_actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
    )
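    # Optionally resume all modules from a checkpoint specified in the config.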
    if ans_cfg.model_path != "":
        self.resume_checkpoint(ans_cfg.model_path)