in gala/model.py [0:0]
def __init__(self, obs_shape, action_space=None, base=None, base_kwargs=None, env_name=None):
    """Build the feature-extractor base and the action-distribution head.

    Args:
        obs_shape: Observation shape. When ``base`` is None its length
            selects the base (3 -> ``CNNBase``, 1 -> ``MLPBase``);
            ``obs_shape[0]`` is always passed to the base as its first
            positional argument.
        action_space: Optional action-space object, dispatched on its
            class name ("Discrete", "Box" or "MultiBinary"). When None,
            the number of discrete actions is looked up from ``env_name``.
        base: Optional callable used to build ``self.base``; overrides the
            selection based on ``obs_shape``.
        base_kwargs: Optional dict of extra keyword arguments for ``base``.
        env_name: Environment id; only read when ``action_space`` is None.
            The text before ``'NoFrameskip'`` is used as the game name.

    Raises:
        NotImplementedError: If ``base`` is None and ``obs_shape`` has a
            length other than 1 or 3, or if the class name of
            ``action_space`` is not one of the supported ones.
        TypeError: If both ``action_space`` and ``env_name`` are None.
        KeyError: If ``env_name`` lacks the ``'NoFrameskip'`` marker or
            names a game missing from the built-in action-count table.
    """
    super(Policy, self).__init__()
    if base_kwargs is None:
        base_kwargs = {}
    if base is None:
        if len(obs_shape) == 3:
            base = CNNBase
        elif len(obs_shape) == 1:
            base = MLPBase
        else:
            raise NotImplementedError

    self.base = base(obs_shape[0], **base_kwargs)

    if action_space is None:
        # No action space supplied: fall back to a fixed per-game table of
        # discrete action counts, keyed by the game part of the env id.
        num_actions = {
            'BeamRider': 9,
            'Breakout': 4,
            'Pong': 6,
            'Qbert': 6,
            'Seaquest': 18,
            'SpaceInvaders': 6,
        }
        if env_name is None:
            raise TypeError(
                "env_name is required when action_space is None")
        # partition() reports a missing marker explicitly; the previous
        # find()-based slice returned -1 in that case and silently chopped
        # the last character off the name instead.
        game, marker, _ = env_name.partition('NoFrameskip')
        if not marker:
            raise KeyError(
                "env_name %r does not contain 'NoFrameskip'" % (env_name,))
        if game not in num_actions:
            raise KeyError(
                "no action count known for game %r (from env_name %r)"
                % (game, env_name))
        num_outputs = num_actions[game]
        self.dist = Categorical(self.base.output_size, num_outputs)
    elif action_space.__class__.__name__ == "Discrete":
        num_outputs = action_space.n
        self.dist = Categorical(self.base.output_size, num_outputs)
    elif action_space.__class__.__name__ == "Box":
        num_outputs = action_space.shape[0]
        self.dist = DiagGaussian(self.base.output_size, num_outputs)
    elif action_space.__class__.__name__ == "MultiBinary":
        num_outputs = action_space.shape[0]
        self.dist = Bernoulli(self.base.output_size, num_outputs)
    else:
        raise NotImplementedError