in ss_baselines/av_wan/ppo/policy.py [0:0]
def __init__(self, observation_space, hidden_size, goal_sensor_uuid, encode_rgb, encode_depth):
    """Build the audio-visual navigation network.

    Encoders are instantiated only for the modalities actually available:
    a visual CNN (rgb/depth), an audio CNN when the goal sensor is a
    spectrogram, and map CNNs for the 'gm' (geometric map) and 'am'
    (acoustic map) sensors.  The concatenated encoder features feed an
    RNN state encoder of size ``hidden_size``.

    Args:
        observation_space: gym-style space with a ``.spaces`` dict keyed by
            sensor uuid ('rgb', 'depth', 'spectrogram', 'gm', 'am', ...).
        hidden_size: feature size of every encoder and of the RNN state.
        goal_sensor_uuid: uuid of the goal observation; the audio encoder
            is created only when this equals 'spectrogram'.
        encode_rgb: whether the visual CNN consumes the 'rgb' sensor.
        encode_depth: whether the visual CNN consumes the 'depth' sensor.
    """
    super().__init__()
    self.goal_sensor_uuid = goal_sensor_uuid
    self._hidden_size = hidden_size
    # NOTE: the original code set self._spectrogram = False here and then
    # unconditionally overwrote it three lines later; the dead assignment
    # has been removed.
    self._gm = 'gm' in observation_space.spaces
    self._am = 'am' in observation_space.spaces
    self._spectrogram = 'spectrogram' == self.goal_sensor_uuid

    self.visual_encoder = VisualCNN(observation_space, hidden_size, encode_rgb, encode_depth)
    if self._spectrogram:
        self.audio_encoder = AudioCNN(observation_space, hidden_size)
    if self._gm:
        self.gm_encoder = MapCNN(observation_space, hidden_size, map_type='gm')
    if self._am:
        self.am_encoder = MapCNN(observation_space, hidden_size, map_type='am')

    # The RNN consumes the concatenation of whichever encoders are active;
    # each contributes hidden_size features (visual contributes nothing if
    # the agent is blind).
    rnn_input_size = (0 if self.is_blind else self._hidden_size) + \
                     (self._hidden_size if self._spectrogram else 0) + \
                     (self._hidden_size if self._gm else 0) + \
                     (self._hidden_size if self._am else 0)
    self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size)

    # Print per-encoder architecture summaries.  Sensor shapes appear to be
    # channels-last (H, W, C) while summary() expects channels-first, hence
    # the (2, 0, 1) axis permutation — TODO confirm against sensor defs.
    if 'rgb' in observation_space.spaces and encode_rgb:
        rgb_shape = observation_space.spaces['rgb'].shape
        summary(self.visual_encoder.cnn, (rgb_shape[2], rgb_shape[0], rgb_shape[1]), device='cpu')
    if 'depth' in observation_space.spaces and encode_depth:
        depth_shape = observation_space.spaces['depth'].shape
        summary(self.visual_encoder.cnn, (depth_shape[2], depth_shape[0], depth_shape[1]), device='cpu')
    # BUGFIX: guard on self._spectrogram instead of space membership.
    # self.audio_encoder only exists when goal_sensor_uuid == 'spectrogram';
    # the old check (`'spectrogram' in observation_space.spaces`) raised
    # AttributeError whenever the spectrogram space was present but the
    # goal sensor was something else.
    if self._spectrogram:
        audio_shape = observation_space.spaces['spectrogram'].shape
        summary(self.audio_encoder.cnn, (audio_shape[2], audio_shape[0], audio_shape[1]), device='cpu')
    if self._gm:
        gm_shape = observation_space.spaces['gm'].shape
        summary(self.gm_encoder.cnn, (gm_shape[2], gm_shape[0], gm_shape[1]), device='cpu')
    if self._am:
        am_shape = observation_space.spaces['am'].shape
        summary(self.am_encoder.cnn, (am_shape[2], am_shape[0], am_shape[1]), device='cpu')
    self.train()