def __init__()

in level_replay/model.py [0:0]


    def __init__(self, obs_shape, num_actions, arch='small', base_kwargs=None):
        """Build the convolutional actor-critic network for grid observations.

        Args:
            obs_shape: Observation shape. Only its last two entries (height,
                width) are read here. The first conv layer takes 3 input
                channels.
            num_actions: Passed to ``Categorical(64, num_actions)`` as the
                size of the action head.
            arch: ``'small'`` gives the last conv layer 32 output channels;
                any other value gives 64.
            base_kwargs: Accepted for interface compatibility; defaulted to an
                empty dict but not otherwise consumed by this constructor.
        """
        super(MinigridPolicy, self).__init__()

        # Kept for signature compatibility; nothing below reads it.
        base_kwargs = {} if base_kwargs is None else base_kwargs

        last_conv_channels = 64
        if arch == 'small':
            last_conv_channels = 32

        # Feature extractor: three 2x2 convs, with a 2x2 max-pool after the
        # first. Layers are created in this exact order so that parameter
        # initialisation and module registration order are stable.
        conv_layers = [
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, last_conv_channels, (2, 2)),
            nn.ReLU(),
        ]
        self.image_conv = nn.Sequential(*conv_layers)

        def shrink(side):
            # conv 2x2 (-1) -> max-pool 2x2 (//2) -> conv 2x2 (-1) -> conv 2x2 (-1)
            return (side - 1) // 2 - 2

        height, width = obs_shape[-2], obs_shape[-1]
        self.image_embedding_size = (
            shrink(height) * shrink(width) * last_conv_channels
        )
        self.embedding_size = self.image_embedding_size

        # Actor trunk: one 64-unit hidden layer followed by tanh.
        actor_layers = [
            init_tanh_(nn.Linear(self.embedding_size, 64)),
            nn.Tanh(),
        ]
        self.actor_base = nn.Sequential(*actor_layers)

        # Critic: 64-unit tanh hidden layer, then a scalar output.
        critic_layers = [
            init_tanh_(nn.Linear(self.embedding_size, 64)),
            nn.Tanh(),
            init_(nn.Linear(64, 1)),
        ]
        self.critic = nn.Sequential(*critic_layers)

        # Action head consumes the 64-dim actor features.
        self.dist = Categorical(64, num_actions)

        # Project helper applied over every registered submodule.
        apply_init_(self.modules())

        self.train()