in monobeast/minigrid/monobeast_amigo.py [0:0]
def __init__(self, observation_shape, width, height, num_input_frames, hidden_dim=256):
    """Build the Generator's embedding tables, conv feature extractor and baseline head.

    Args:
        observation_shape: Stored unchanged on ``self.observation_shape``.
        width: Grid width in cells.
        height: Grid height in cells.
        num_input_frames: Number of stacked input frames; multiplies the
            number of conv input channels.
        hidden_dim: Not used by this constructor; kept so existing callers
            keep working.

    Reads the module-level ``flags`` object (``flags.disable_use_embedding``
    and ``flags.inner``), which must exist before construction.
    """
    super(Generator, self).__init__()
    self.observation_shape = observation_shape
    self.height = height
    self.width = width
    self.env_dim = self.width * self.height
    self.state_embedding_dim = 256

    self.use_index_select = True
    # Per-cell embedding widths for the object / color / contains channels.
    self.obj_dim = 5
    self.col_dim = 3
    self.con_dim = 2
    self.num_channels = (self.obj_dim + self.col_dim + self.con_dim) * num_input_frames
    if flags.disable_use_embedding:
        # Without embeddings each frame contributes 3 raw input channels.
        print("not_using_embedding")
        self.num_channels = 3 * num_input_frames

    # Created unconditionally, regardless of flags.disable_use_embedding.
    # NOTE(review): vocab sizes 11/6/4 presumably match the env's object,
    # color and state index ranges - confirm against the MiniGrid version.
    self.embed_object = nn.Embedding(11, self.obj_dim)
    self.embed_color = nn.Embedding(6, self.col_dim)
    self.embed_contains = nn.Embedding(4, self.con_dim)

    K = self.num_channels  # number of input filters
    F = 3  # filter dimensions
    S = 1  # stride
    P = 1  # padding
    M = 16  # number of intermediate filters
    L = 4  # number of convnet layers
    E = 1  # output of last layer

    # Channel schedule K -> M -> ... -> M -> E, derived from L so both lists
    # always have exactly one entry per conv layer.
    in_channels = [K] + [M] * (L - 1)
    out_channels = [M] * (L - 1) + [E]

    # Kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
    conv_extract = [
        nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=(F, F),
            stride=S,
            padding=P,
        )
        for c_in, c_out in zip(in_channels, out_channels)
    ]

    # Conv, ELU, Conv, ELU, ... with one ELU instance filling every
    # activation slot (ELU has no parameters, so sharing it is harmless).
    elu = nn.ELU()
    layers = []
    for conv in conv_extract:
        layers.extend((conv, elu))
    self.extract_representation = nn.Sequential(*layers)

    self.out_dim = self.env_dim * M + self.obj_dim + self.col_dim

    def init_(module):
        # `init` is defined elsewhere in this file; presumably it applies the
        # given weight initializer (orthogonal) and bias initializer (zero).
        return init(module, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0))

    # With flags.inner the baseline input covers only the
    # (height - 2) x (width - 2) interior, i.e. the grid minus a 1-cell border.
    if flags.inner:
        self.aux_env_dim = (self.height - 2) * (self.width - 2)
    else:
        self.aux_env_dim = self.env_dim
    self.baseline_teacher = init_(nn.Linear(self.aux_env_dim, 1))