def __init__()

in gala/model.py [0:0]


    def __init__(self, obs_shape, action_space=None, base=None, base_kwargs=None, env_name=None):
        """Build the feature extractor (``self.base``) and action head (``self.dist``).

        Args:
            obs_shape: Observation shape. When ``base`` is not given, its
                length selects the extractor: 3 -> ``CNNBase``,
                1 -> ``MLPBase``. ``obs_shape[0]`` is passed to the base as
                its first positional argument.
            action_space: Optional action-space object, dispatched on its
                class name: ``"Discrete"`` (uses ``.n``), ``"Box"`` or
                ``"MultiBinary"`` (both use ``.shape[0]``). When ``None``,
                the number of discrete actions is looked up from
                ``env_name`` instead.
            base: Optional callable/class used to build ``self.base``.
            base_kwargs: Optional dict of extra keyword arguments forwarded
                to ``base``; defaults to an empty dict.
            env_name: Environment name such as ``"PongNoFrameskip-v4"``.
                Only consulted when ``action_space`` is ``None``; the text
                before ``"NoFrameskip"`` (or the whole string if that marker
                is absent) must be one of the games in the built-in table.

        Raises:
            NotImplementedError: If ``base`` is ``None`` and ``obs_shape`` has
                a length other than 1 or 3, or if ``action_space`` is of an
                unsupported class.
            TypeError: If both ``action_space`` and ``env_name`` are ``None``.
            KeyError: If the game derived from ``env_name`` is not in the
                built-in action-count table.
        """
        super(Policy, self).__init__()
        if base_kwargs is None:
            base_kwargs = {}
        if base is None:
            if len(obs_shape) == 3:
                base = CNNBase
            elif len(obs_shape) == 1:
                base = MLPBase
            else:
                raise NotImplementedError

        self.base = base(obs_shape[0], **base_kwargs)

        if action_space is None:
            if env_name is None:
                raise TypeError(
                    "env_name must be provided when action_space is None")
            # str.partition returns the whole string when the marker is
            # missing, unlike slicing with str.find() whose -1 result would
            # silently chop off the last character of the name.
            game = env_name.partition('NoFrameskip')[0]
            num_actions = {
                'BeamRider': 9,
                'Breakout': 4,
                'Pong': 6,
                'Qbert': 6,
                'Seaquest': 18,
                'SpaceInvaders': 6,
            }
            if game not in num_actions:
                raise KeyError(
                    "unknown game {!r} derived from env_name {!r}; "
                    "known games: {}".format(
                        game, env_name, ', '.join(sorted(num_actions))))
            num_outputs = num_actions[game]
            self.dist = Categorical(self.base.output_size, num_outputs)
        elif action_space.__class__.__name__ == "Discrete":
            num_outputs = action_space.n
            self.dist = Categorical(self.base.output_size, num_outputs)
        elif action_space.__class__.__name__ == "Box":
            num_outputs = action_space.shape[0]
            self.dist = DiagGaussian(self.base.output_size, num_outputs)
        elif action_space.__class__.__name__ == "MultiBinary":
            num_outputs = action_space.shape[0]
            self.dist = Bernoulli(self.base.output_size, num_outputs)
        else:
            raise NotImplementedError