# From ss_baselines/av_nav/ppo/policy.py
def __init__(self, observation_space, hidden_size, goal_sensor_uuid, extra_rgb=False):
    """Build the audio-visual navigation policy network.

    Configures a visual CNN encoder, an optional audio CNN encoder, and an
    RNN state encoder whose input width depends on which goal sensors are
    active.

    Args:
        observation_space: gym-style space; its ``.spaces`` dict is read for
            'rgb', 'depth', and the goal-sensor entries named by
            ``goal_sensor_uuid``.
        hidden_size: width of the visual/audio embeddings and the RNN state.
        goal_sensor_uuid: goal sensor name, or two names joined by
            ``DUAL_GOAL_DELIMITER`` (pointgoal uuid first, audio uuid second).
        extra_rgb: forwarded to ``VisualCNN``; when True the RGB summary
            printout is skipped.

    Raises:
        ValueError: if an audio goal is required but ``goal_sensor_uuid``
            names neither 'audiogoal' nor 'spectrogram'.
    """
    super().__init__()
    self.goal_sensor_uuid = goal_sensor_uuid
    self._hidden_size = hidden_size
    self._audiogoal = False
    self._pointgoal = False
    self._n_pointgoal = 0

    # Decide which goal modalities are active from the sensor uuid.
    if DUAL_GOAL_DELIMITER in self.goal_sensor_uuid:
        goal1_uuid, goal2_uuid = self.goal_sensor_uuid.split(DUAL_GOAL_DELIMITER)
        self._audiogoal = self._pointgoal = True
        self._n_pointgoal = observation_space.spaces[goal1_uuid].shape[0]
    elif self.goal_sensor_uuid == 'pointgoal_with_gps_compass':
        self._pointgoal = True
        self._n_pointgoal = observation_space.spaces[self.goal_sensor_uuid].shape[0]
    else:
        self._audiogoal = True

    self.visual_encoder = VisualCNN(observation_space, hidden_size, extra_rgb)
    if self._audiogoal:
        # Pick the concrete audio observation key. Fail fast with a clear
        # message instead of the NameError the original code hit when the
        # uuid named neither audio sensor.
        if 'audiogoal' in self.goal_sensor_uuid:
            audiogoal_sensor = 'audiogoal'
        elif 'spectrogram' in self.goal_sensor_uuid:
            audiogoal_sensor = 'spectrogram'
        else:
            raise ValueError(
                f"goal_sensor_uuid {self.goal_sensor_uuid!r} does not name an "
                "audio goal sensor ('audiogoal' or 'spectrogram')"
            )
        self.audio_encoder = AudioCNN(observation_space, hidden_size, audiogoal_sensor)

    # RNN input is the concatenation of whichever embeddings are active:
    # visual (unless blind), pointgoal vector, and audio embedding.
    rnn_input_size = (0 if self.is_blind else self._hidden_size) + \
                     (self._n_pointgoal if self._pointgoal else 0) + \
                     (self._hidden_size if self._audiogoal else 0)
    self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size)

    # Print model summaries for each active encoder. Observation shapes are
    # HWC; summary() expects CHW, hence the index shuffle.
    if 'rgb' in observation_space.spaces and not extra_rgb:
        rgb_shape = observation_space.spaces['rgb'].shape
        summary(self.visual_encoder.cnn, (rgb_shape[2], rgb_shape[0], rgb_shape[1]), device='cpu')
    if 'depth' in observation_space.spaces:
        depth_shape = observation_space.spaces['depth'].shape
        summary(self.visual_encoder.cnn, (depth_shape[2], depth_shape[0], depth_shape[1]), device='cpu')
    if self._audiogoal:
        audio_shape = observation_space.spaces[audiogoal_sensor].shape
        summary(self.audio_encoder.cnn, (audio_shape[2], audio_shape[0], audio_shape[1]), device='cpu')

    self.train()