in minihack/agent/polybeast/models/base.py [0:0]
def __init__(self, observation_shape, num_actions, flags, device):
    super(BaseNet, self).__init__()

    self.flags = flags

    self.observation_shape = observation_shape
    self.H = observation_shape[0]
    self.W = observation_shape[1]

    self.num_actions = num_actions
    self.use_lstm = flags.use_lstm

    self.k_dim = flags.embedding_dim
    self.h_dim = flags.hidden_dim

    self.crop_model = flags.crop_model
    self.crop_dim = flags.crop_dim

    self.num_features = NUM_FEATURES

    self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim, device)
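    # `self.crop` extracts an egocentric crop_dim x crop_dim window of the
    # glyph map, centred on coordinates supplied at forward time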
    self.glyph_type = flags.glyph_type
    self.glyph_embedding = GlyphEmbedding(
        flags.glyph_type,
        flags.embedding_dim,
        device,
        flags.use_index_select,
    )
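    # embeds each of the H x W glyph ids into a k_dim vector; use_index_select
    # swaps the nn.Embedding lookup for index_select, a common workaround for
    # nn.Embedding's slow backward pass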
    K = flags.embedding_dim  # number of input filters
    F = 3  # filter dimensions
    S = 1  # stride
    P = 1  # padding
    M = 16  # number of intermediate filters
    self.Y = 8  # number of output filters
    L = flags.layers  # number of convnet layers

    in_channels = [K] + [M] * (L - 1)
    out_channels = [M] * (L - 1) + [self.Y]
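    # e.g. with L = 5: in_channels == [K, 16, 16, 16, 16] and
    # out_channels == [16, 16, 16, 16, 8]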
    def interleave(xs, ys):
        return [val for pair in zip(xs, ys) for val in pair]
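    # e.g. interleave([c1, c2], [e1, e2]) == [c1, e1, c2, e2], pairing each
    # conv layer with its ELU activation below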
    conv_extract = [
        nn.Conv2d(
            in_channels=in_channels[i],
            out_channels=out_channels[i],
            kernel_size=(F, F),
            stride=S,
            padding=P,
        )
        for i in range(L)
    ]

    self.extract_representation = nn.Sequential(
        *interleave(conv_extract, [nn.ELU()] * len(conv_extract))
    )
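    # with F = 3, S = 1, P = 1 every conv preserves spatial size, so the
    # full-map representation is [B x Y x H x W], i.e. H * W * Y features
    # once flattened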
    if self.crop_model == "transformer":
        self.extract_crop_representation = TransformerEncoder(
            K,
            N=L,
            heads=8,
            height=self.crop_dim,
            width=self.crop_dim,
            device=device,
        )
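        # the transformer keeps K channels per crop cell, so the flattened crop
        # representation is crop_dim ** 2 * K (vs. crop_dim ** 2 * Y for the cnn)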
    elif self.crop_model == "cnn":
        conv_extract_crop = [
            nn.Conv2d(
                in_channels=in_channels[i],
                out_channels=out_channels[i],
                kernel_size=(F, F),
                stride=S,
                padding=P,
            )
            for i in range(L)
        ]
        self.extract_crop_representation = nn.Sequential(
            *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract_crop))
        )
    # MESSAGING MODEL
    if "msg" not in flags:
        self.msg_model = "none"
    else:
        self.msg_model = flags.msg.model
        self.msg_hdim = flags.msg.hidden_dim
        self.msg_edim = flags.msg.embedding_dim

        if self.msg_model in ("gru", "lstm", "lt_cnn"):
            # character-based embeddings
            self.char_lt = nn.Embedding(
                NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
            )
        else:
            # forward will set up one-hot inputs for the cnn,
            # so no lookup table is needed
            pass
        if self.msg_model.endswith("cnn"):
            # character-level CNN from Zhang et al., 2016:
            # Character-level Convolutional Networks for Text Classification
            # https://arxiv.org/abs/1509.01626
            if self.msg_model == "cnn":
                # inputs will be one-hot vectors, as done in the paper
                self.conv1 = nn.Conv1d(NUM_CHARS, self.msg_hdim, kernel_size=7)
            elif self.msg_model == "lt_cnn":
                # replace one-hot inputs with learned embeddings
                self.conv1 = nn.Conv1d(
                    self.msg_edim, self.msg_hdim, kernel_size=7
                )
            else:
                raise NotImplementedError("msg.model == %s" % flags.msg.model)
            # remaining convolutions, relus, pools, and a small FC network
            self.conv2_6_fc = nn.Sequential(
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # conv2
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=7),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # conv3
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv4
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv5
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv6
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # fc receives -- [ B x msg_hdim x 5 ]
                Flatten(),
                nn.Linear(5 * self.msg_hdim, 2 * self.msg_hdim),
                nn.ReLU(),
                nn.Linear(2 * self.msg_hdim, self.msg_hdim),
            )  # final output -- [ B x msg_hdim ]
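            # assuming the standard 256-character NetHack message, the sequence
            # length shrinks 256 ->(conv7) 250 ->(pool3) 83 ->(conv7) 77
            # ->(pool3) 25 ->(conv3 x 4) 17 ->(pool3) 5, hence the
            # 5 * msg_hdim input to the fully connected layers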
        elif self.msg_model in ("gru", "lstm"):

            def rnn(flag):
                return nn.LSTM if flag == "lstm" else nn.GRU

            self.char_rnn = rnn(self.msg_model)(
                self.msg_edim,
                self.msg_hdim // 2,
                batch_first=True,
                bidirectional=True,
            )
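            # msg_hdim // 2 hidden units per direction, so the concatenated
            # bidirectional output is msg_hdim-dimensional, matching the
            # cnn branch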
        elif self.msg_model != "none":
            raise NotImplementedError("msg.model == %s" % flags.msg.model)
    self.embed_features = nn.Sequential(
        nn.Linear(self.num_features, self.k_dim),
        nn.ReLU(),
        nn.Linear(self.k_dim, self.k_dim),
        nn.ReLU(),
    )
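    # two-layer MLP embedding the NUM_FEATURES status features into k_dim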
    self.equalize_input_dim = flags.equalize_input_dim
    if not self.equalize_input_dim:
        # just add up the output dimensions of the input featurizers
        # feature / status dim
        out_dim = self.k_dim
        # CNN over full glyph map
        out_dim += self.H * self.W * self.Y

        if self.crop_model == "transformer":
            out_dim += self.crop_dim ** 2 * K
        elif self.crop_model == "cnn":
            out_dim += self.crop_dim ** 2 * self.Y

        # messaging model
        if self.msg_model != "none":
            out_dim += self.msg_hdim
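        # for instance, with the full 21 x 79 NetHack map and crop_model ==
        # "cnn": out_dim = k_dim + 21 * 79 * 8 + crop_dim ** 2 * 8 + msg_hdim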
    else:
        # otherwise, project them all to h_dim
        NUM_INPUTS = 4 if self.msg_model != "none" else 3
        project_hdim = flags.equalize_factor * self.h_dim
        out_dim = project_hdim * NUM_INPUTS
        # set up linear layers for projections
        self.project_feature_dim = nn.Linear(self.k_dim, project_hdim)
        self.project_glyph_dim = nn.Linear(
            self.H * self.W * self.Y, project_hdim
        )
        c__2 = self.crop_dim ** 2
        if self.crop_model == "transformer":
            self.project_crop_dim = nn.Linear(c__2 * K, project_hdim)
        elif self.crop_model == "cnn":
            self.project_crop_dim = nn.Linear(c__2 * self.Y, project_hdim)
        if self.msg_model != "none":
            self.project_msg_dim = nn.Linear(self.msg_hdim, project_hdim)
    self.fc = nn.Sequential(
        nn.Linear(out_dim, self.h_dim),
        nn.ReLU(),
        nn.Linear(self.h_dim, self.h_dim),
        nn.ReLU(),
    )

    if self.use_lstm:
        self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

    self.policy = nn.Linear(self.h_dim, self.num_actions)
    self.baseline = nn.Linear(self.h_dim, 1)
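    # `policy` produces per-action logits and `baseline` a scalar value
    # estimate -- the actor-critic heads consumed by the polybeast (IMPALA)
    # learner.
    #
    # Minimal construction sketch (the flag values below are illustrative
    # assumptions, not the project defaults; any omegaconf-style config that
    # supports `"msg" in flags` plus attribute access will do):
    #
    #   from omegaconf import OmegaConf
    #   flags = OmegaConf.create(dict(
    #       use_lstm=True, embedding_dim=64, hidden_dim=256,
    #       crop_model="cnn", crop_dim=9, layers=5,
    #       glyph_type="all_cat", use_index_select=True,
    #       equalize_input_dim=False,
    #   ))
    #   net = BaseNet(observation_shape=(21, 79), num_actions=8,
    #                 flags=flags, device=torch.device("cpu"))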