in ss_baselines/savi/ddppo/algo/ddppo_trainer.py [0:0]
def _setup_actor_critic_agent(self, ppo_cfg: Config, observation_space=None) -> None:
    r"""Sets up the actor critic and agent for DD-PPO.

    Args:
        ppo_cfg: config node with relevant params
        observation_space: optional space override; unused here, the observation
            space is taken from self.envs

    Returns:
        None
    """
    logger.add_filehandler(self.config.LOG_FILE)

    action_space = self.envs.action_spaces[0]
    self.action_space = action_space
    has_distractor_sound = self.config.TASK_CONFIG.SIMULATOR.AUDIO.HAS_DISTRACTOR_SOUND
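
    # Two policy types are supported: the recurrent AudioNav baseline ('rnn') and
    # the scene memory transformer policy ('smt') used for SAVi.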
    if ppo_cfg.policy_type == 'rnn':
        self.actor_critic = AudioNavBaselinePolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.action_space,
            hidden_size=ppo_cfg.hidden_size,
            goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
            extra_rgb=self.config.EXTRA_RGB,
            use_mlp_state_encoder=ppo_cfg.use_mlp_state_encoder
        )
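
        # Optionally attach a belief predictor that produces goal beliefs for the
        # policy; the DDP variant is used when the predictor is trained online.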
        if ppo_cfg.use_belief_predictor:
            belief_cfg = ppo_cfg.BELIEF_PREDICTOR
            bp_class = BeliefPredictorDDP if belief_cfg.online_training else BeliefPredictor
            self.belief_predictor = bp_class(belief_cfg, self.device, None, None,
                                             ppo_cfg.hidden_size, self.envs.num_envs, has_distractor_sound
                                             ).to(device=self.device)
            if belief_cfg.online_training:
                params = list(self.belief_predictor.predictor.parameters())
                if belief_cfg.train_encoder:
                    params += list(self.actor_critic.net.goal_encoder.parameters()) + \
                              list(self.actor_critic.net.visual_encoder.parameters()) + \
                              list(self.actor_critic.net.action_encoder.parameters())
                self.belief_predictor.optimizer = torch.optim.Adam(params, lr=belief_cfg.lr)
            self.belief_predictor.freeze_encoders()
    elif ppo_cfg.policy_type == 'smt':
        smt_cfg = ppo_cfg.SCENE_MEMORY_TRANSFORMER
        belief_cfg = ppo_cfg.BELIEF_PREDICTOR
        self.actor_critic = AudioNavSMTPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            hidden_size=smt_cfg.hidden_size,
            nhead=smt_cfg.nhead,
            num_encoder_layers=smt_cfg.num_encoder_layers,
            num_decoder_layers=smt_cfg.num_decoder_layers,
            dropout=smt_cfg.dropout,
            activation=smt_cfg.activation,
            use_pretrained=smt_cfg.use_pretrained,
            pretrained_path=smt_cfg.pretrained_path,
            pretraining=smt_cfg.pretraining,
            use_belief_encoding=smt_cfg.use_belief_encoding,
            use_belief_as_goal=ppo_cfg.use_belief_predictor,
            use_label_belief=belief_cfg.use_label_belief,
            use_location_belief=belief_cfg.use_location_belief,
            normalize_category_distribution=belief_cfg.normalize_category_distribution,
            use_category_input=has_distractor_sound
        )
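
        # When freeze_encoders is set, the policy's feature encoders are kept fixed
        # during training and the trainer is flagged to treat them as static.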
        if smt_cfg.freeze_encoders:
            self._static_smt_encoder = True
            self.actor_critic.net.freeze_encoders()
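
        # The belief predictor is sized from the SMT state encoder (input size, pose
        # indices, hidden state size) so its outputs can serve as the policy's goal belief.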
        if ppo_cfg.use_belief_predictor:
            smt = self.actor_critic.net.smt_state_encoder
            bp_class = BeliefPredictorDDP if belief_cfg.online_training else BeliefPredictor
            self.belief_predictor = bp_class(belief_cfg, self.device, smt._input_size, smt._pose_indices,
                                             smt.hidden_state_size, self.envs.num_envs, has_distractor_sound
                                             ).to(device=self.device)
            if belief_cfg.online_training:
                params = list(self.belief_predictor.predictor.parameters())
                if belief_cfg.train_encoder:
                    params += list(self.actor_critic.net.goal_encoder.parameters()) + \
                              list(self.actor_critic.net.visual_encoder.parameters()) + \
                              list(self.actor_critic.net.action_encoder.parameters())
                self.belief_predictor.optimizer = torch.optim.Adam(params, lr=belief_cfg.lr)
            self.belief_predictor.freeze_encoders()
    else:
        raise ValueError(f'Policy type {ppo_cfg.policy_type} is not defined!')

    self.actor_critic.to(self.device)

    if self.config.RL.DDPPO.pretrained:
        # Load pretrained weights: the actor critic is restored without its visual
        # and SMT state encoder weights, then the RGB and depth encoders are loaded
        # separately from the same checkpoint.
        pretrained_state = torch.load(self.config.RL.DDPPO.pretrained_weights, map_location="cpu")
        self.actor_critic.load_state_dict(
            {
                k[len("actor_critic."):]: v
                for k, v in pretrained_state["state_dict"].items()
                if "actor_critic.net.visual_encoder" not in k and
                   "actor_critic.net.smt_state_encoder" not in k
            },
            strict=False
        )
        self.actor_critic.net.visual_encoder.rgb_encoder.load_state_dict(
            {
                k[len("actor_critic.net.visual_encoder.rgb_encoder."):]: v
                for k, v in pretrained_state["state_dict"].items()
                if "actor_critic.net.visual_encoder.rgb_encoder." in k
            },
        )
        self.actor_critic.net.visual_encoder.depth_encoder.load_state_dict(
            {
                k[len("actor_critic.net.visual_encoder.depth_encoder."):]: v
                for k, v in pretrained_state["state_dict"].items()
                if "actor_critic.net.visual_encoder.depth_encoder." in k
            },
        )

    if self.config.RL.DDPPO.reset_critic:
        # Re-initialize the critic head so the value function is learned from scratch.
        nn.init.orthogonal_(self.actor_critic.critic.fc.weight)
        nn.init.constant_(self.actor_critic.critic.fc.bias, 0)
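
    # Wrap the policy in the DD-PPO updater with the PPO hyperparameters from the config.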
    self.agent = DDPPO(
        actor_critic=self.actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )