in level_replay/model.py [0:0]
def __init__(self, obs_shape, num_actions, arch='small', base_kwargs=None):
    """Build the convolutional actor-critic network.

    Args:
        obs_shape: Observation shape. Only the last two entries are read;
            they are used as the spatial height and width of the image.
            The first convolution is hard-coded to 3 input channels
            (presumably a channel-first (3, H, W) observation -- TODO confirm).
        num_actions: Number of discrete actions for the ``Categorical`` head.
        arch: ``'small'`` gives 32 channels in the last conv layer; any other
            value gives 64.
        base_kwargs: Accepted for signature compatibility; not used by this
            constructor.

    Raises:
        ValueError: If either spatial dimension of ``obs_shape`` is smaller
            than 7. Below that size the conv stack yields an empty or
            negative-sized feature map.
    """
    super(MinigridPolicy, self).__init__()

    final_channels = 32 if arch == 'small' else 64

    # Conv(2x2) -> ReLU -> MaxPool(2x2) -> Conv(2x2) -> ReLU -> Conv(2x2) -> ReLU
    self.image_conv = nn.Sequential(
        nn.Conv2d(3, 16, (2, 2)),
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
        nn.Conv2d(16, 32, (2, 2)),
        nn.ReLU(),
        nn.Conv2d(32, final_channels, (2, 2)),
        nn.ReLU()
    )

    # Spatial size through the stack above (stride 1, no padding, floor pool):
    #   s -> s-1 (conv) -> (s-1)//2 (pool) -> (s-1)//2-1 (conv) -> (s-1)//2-2 (conv)
    n = obs_shape[-2]
    m = obs_shape[-1]
    out_h = (n - 1) // 2 - 2
    out_w = (m - 1) // 2 - 2
    # Validate each dimension on its own: if both are too small, both factors
    # are negative and their product would look like a valid positive size.
    if out_h < 1 or out_w < 1:
        raise ValueError(
            f"MinigridPolicy requires spatial observation dims >= 7, "
            f"got {n}x{m}"
        )
    self.image_embedding_size = out_h * out_w * final_channels
    self.embedding_size = self.image_embedding_size

    # Define actor's model: shared 64-unit trunk feeding the action distribution.
    self.actor_base = nn.Sequential(
        init_tanh_(nn.Linear(self.embedding_size, 64)),
        nn.Tanh(),
    )

    # Define critic's model: 64-unit hidden layer -> scalar value.
    self.critic = nn.Sequential(
        init_tanh_(nn.Linear(self.embedding_size, 64)),
        nn.Tanh(),
        init_(nn.Linear(64, 1))
    )

    # Action head consumes the 64-dim output of actor_base.
    self.dist = Categorical(64, num_actions)

    apply_init_(self.modules())

    self.train()