in level_replay/model.py [0:0]
def __init__(self, obs_shape, num_actions, arch='small', base_kwargs=None):
    """Build the policy from a feature base and a ``Categorical`` action head.

    Args:
        obs_shape: Shape of a single observation. A rank-3 shape selects a
            convolutional base (``SmallNetBase`` when ``arch == 'small'``,
            otherwise ``ResNetBase``). A rank-1 shape selects ``MLPBase``.
            In every case ``obs_shape[0]`` is passed to the base as its first
            positional argument. Presumably this is the channel count or
            input size (TODO confirm against the base classes).
        num_actions: Number of discrete actions. It is passed to
            ``Categorical`` together with the base's ``output_size``.
        arch: Backbone name for rank-3 observations. Only ``'small'`` is
            matched explicitly; any other value selects ``ResNetBase``.
        base_kwargs: Extra keyword arguments forwarded to the base
            constructor. ``None`` means no extra arguments.

    Raises:
        NotImplementedError: If ``obs_shape`` is neither rank 1 nor rank 3.
    """
    super(Policy, self).__init__()
    if base_kwargs is None:
        # Fresh dict per call; avoids a shared mutable default.
        base_kwargs = {}

    if len(obs_shape) == 3:
        base = SmallNetBase if arch == 'small' else ResNetBase
    elif len(obs_shape) == 1:
        base = MLPBase
    else:
        # Fail loudly here. Otherwise `base` would be unbound and the next
        # statement would raise an opaque UnboundLocalError.
        raise NotImplementedError(
            "Unsupported observation shape {!r}: expected rank 1 or 3".format(
                obs_shape))

    self.base = base(obs_shape[0], **base_kwargs)
    self.dist = Categorical(self.base.output_size, num_actions)