def __init__()

in ss_baselines/av_wan/ppo/policy.py [0:0]


    def __init__(self, observation_space, hidden_size, goal_sensor_uuid, encode_rgb, encode_depth):
        """Build the per-modality encoders and the recurrent state encoder.

        Args:
            observation_space: Object whose ``spaces`` mapping is probed for
                the ``'rgb'``, ``'depth'``, ``'spectrogram'``, ``'gm'`` and
                ``'am'`` keys; it is also forwarded to every encoder.
            hidden_size: Forwarded to every encoder and used as the hidden
                size of the RNN state encoder. Each enabled modality is
                counted as contributing ``hidden_size`` features to the
                RNN input.
            goal_sensor_uuid: Stored on the instance. The audio encoder is
                only built when this equals ``'spectrogram'``.
            encode_rgb: Forwarded to ``VisualCNN``; also gates the rgb
                summary printout.
            encode_depth: Forwarded to ``VisualCNN``; also gates the depth
                summary printout.
        """
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid
        self._hidden_size = hidden_size

        spaces = observation_space.spaces
        self._gm = 'gm' in spaces
        self._am = 'am' in spaces
        self._spectrogram = 'spectrogram' == self.goal_sensor_uuid

        self.visual_encoder = VisualCNN(observation_space, hidden_size, encode_rgb, encode_depth)
        if self._spectrogram:
            self.audio_encoder = AudioCNN(observation_space, hidden_size)
        if self._gm:
            self.gm_encoder = MapCNN(observation_space, hidden_size, map_type='gm')
        if self._am:
            self.am_encoder = MapCNN(observation_space, hidden_size, map_type='am')

        # One hidden_size-wide slot per enabled modality; the visual slot is
        # dropped when the policy is blind.
        enabled_modalities = (not self.is_blind, self._spectrogram, self._gm, self._am)
        rnn_input_size = sum(self._hidden_size for enabled in enabled_modalities if enabled)
        self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size)

        def _summarize(cnn, space_key):
            # Reorders the space's shape from (s0, s1, s2) to (s2, s0, s1)
            # before handing it to summary().
            # NOTE(review): presumably HWC -> CHW; confirm against the sensors.
            shape = spaces[space_key].shape
            summary(cnn, (shape[2], shape[0], shape[1]), device='cpu')

        # NOTE(review): when both rgb and depth are encoded the same visual
        # CNN is summarized twice with single-modality shapes; confirm that
        # VisualCNN accepts those inputs.
        if 'rgb' in spaces and encode_rgb:
            _summarize(self.visual_encoder.cnn, 'rgb')
        if 'depth' in spaces and encode_depth:
            _summarize(self.visual_encoder.cnn, 'depth')
        # The audio encoder only exists when the goal sensor is the
        # spectrogram, so require both conditions; checking the observation
        # space alone raised AttributeError for other goal sensors.
        if self._spectrogram and 'spectrogram' in spaces:
            _summarize(self.audio_encoder.cnn, 'spectrogram')
        if self._gm:
            _summarize(self.gm_encoder.cnn, 'gm')
        if self._am:
            _summarize(self.am_encoder.cnn, 'am')

        self.train()