in ss_baselines/savi/ppo/policy.py [0:0]
def __init__(self, observation_space, hidden_size, goal_sensor_uuid, extra_rgb=False, use_mlp_state_encoder=False):
    """Build the audio-visual navigation policy network.

    Args:
        observation_space: gym-style space whose ``.spaces`` dict is inspected
            for the goal-sensor uuid(s), ``'category'``, and belief uuids.
        hidden_size: width of the visual/audio encoder outputs and of the
            state encoder.
        goal_sensor_uuid: uuid of the goal sensor. May be two uuids joined by
            ``DUAL_GOAL_DELIMITER`` (pointgoal uuid first, audiogoal second),
            ``'pointgoal_with_gps_compass'``, or an audiogoal uuid containing
            ``'audiogoal'`` or ``'spectrogram'``.
        extra_rgb: forwarded to ``VisualCNN``.
        use_mlp_state_encoder: if True, use a single ``nn.Linear`` layer as
            the state encoder instead of an RNN.

    Raises:
        ValueError: if an audiogoal is requested but the uuid names neither
            ``'audiogoal'`` nor ``'spectrogram'``.
    """
    super().__init__()
    self.goal_sensor_uuid = goal_sensor_uuid
    self._hidden_size = hidden_size
    self._audiogoal = False
    self._pointgoal = False
    self._n_pointgoal = 0
    # 'category' observation carries the semantic goal label when present.
    self._label = 'category' in observation_space.spaces
    # Goal-descriptor beliefs are disabled by default.
    self._use_label_belief = False
    self._use_location_belief = False
    self._use_mlp_state_encoder = use_mlp_state_encoder

    if DUAL_GOAL_DELIMITER in self.goal_sensor_uuid:
        # Dual goal: "<pointgoal_uuid><DELIM><audiogoal_uuid>".
        goal1_uuid, goal2_uuid = self.goal_sensor_uuid.split(DUAL_GOAL_DELIMITER)
        self._audiogoal = self._pointgoal = True
        self._n_pointgoal = observation_space.spaces[goal1_uuid].shape[0]
    elif self.goal_sensor_uuid == 'pointgoal_with_gps_compass':
        self._pointgoal = True
        self._n_pointgoal = observation_space.spaces[self.goal_sensor_uuid].shape[0]
    else:
        self._audiogoal = True

    self.visual_encoder = VisualCNN(observation_space, hidden_size, extra_rgb)
    if self._audiogoal:
        if 'audiogoal' in self.goal_sensor_uuid:
            audiogoal_sensor = 'audiogoal'
        elif 'spectrogram' in self.goal_sensor_uuid:
            audiogoal_sensor = 'spectrogram'
        else:
            # BUGFIX: previously audiogoal_sensor was left unbound here,
            # crashing below with an opaque NameError; fail fast instead.
            raise ValueError(
                'Unsupported audiogoal sensor uuid: {}'.format(self.goal_sensor_uuid))
        self.audio_encoder = AudioCNN(observation_space, hidden_size, audiogoal_sensor)

    # State-encoder input width: concatenation of every enabled feature stream.
    rnn_input_size = (0 if self.is_blind else self._hidden_size) + \
                     (self._n_pointgoal if self._pointgoal else 0) + \
                     (self._hidden_size if self._audiogoal else 0) + \
                     (observation_space.spaces['category'].shape[0] if self._label else 0) + \
                     (observation_space.spaces[CategoryBelief.cls_uuid].shape[0] if self._use_label_belief else 0) + \
                     (observation_space.spaces[LocationBelief.cls_uuid].shape[0] if self._use_location_belief else 0)
    if not self._use_mlp_state_encoder:
        self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size)
    else:
        self.state_encoder = nn.Linear(rnn_input_size, self._hidden_size)

    if not self.visual_encoder.is_blind:
        summary(self.visual_encoder.cnn, self.visual_encoder.input_shape, device='cpu')
    if self._audiogoal:
        # Audio observations are stored HWC; torchsummary expects CHW.
        audio_shape = observation_space.spaces[audiogoal_sensor].shape
        summary(self.audio_encoder.cnn, (audio_shape[2], audio_shape[0], audio_shape[1]), device='cpu')
    self.train()