in minihack/agent/rllib/models.py [0:0]
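# Module-level context assumed from the rest of models.py (not shown in this
# excerpt): `torch`, `torch.nn.functional as F`, and the `NUM_CHARS` character
# vocabulary size constant used by the message encoder.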
def forward(self, inputs):
    B, *_ = inputs["glyphs"].shape
    glyphs, features = self.prepare_input(inputs)
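    # `glyphs` is a tuple of per-layer glyph tensors (each [B x H x W]) and
    # `features` is the per-step stats vector; their exact contents depend on
    # the observation keys the model was configured with.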
    # -- [B x 2] x,y coordinates
    coordinates = features[:, :2]
    # -- [B x K]
    features_emb = self.embed_features(features)
    if self.equalize_input_dim:
        features_emb = self.project_feature_dim(features_emb)
    assert features_emb.shape[0] == B
    reps = [features_emb]  # either k_dim or project_hdim
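    # Egocentric crop: cut an H' x W' window centred on the agent's (x, y)
    # coordinates out of every glyph layer, then embed the cropped glyphs.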
    # -- [B x H' x W']
    crop = self.glyph_embedding.GlyphTuple(
        *[self.crop(g, coordinates) for g in glyphs]
    )
    # -- [B x H' x W' x K]
    crop_emb = self.glyph_embedding(crop)
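    # The crop is encoded by one of two interchangeable models: a transformer
    # over the glyph embeddings, or a CNN that expects channels-first input.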
    if self.crop_model == "transformer":
        # -- [B x H' x W' x K]
        crop_rep = self.extract_crop_representation(crop_emb, mask=None)
    elif self.crop_model == "cnn":
        # -- [B x K x W' x H']: channels-first for the conv stack
        crop_emb = crop_emb.transpose(1, 3)
        crop_rep = self.extract_crop_representation(crop_emb)
    # -- [B x K']
    crop_rep = crop_rep.view(B, -1)
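    # Optionally project the flattened crop features so that each modality
    # contributes a vector of the same size to the final concatenation.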
    if self.equalize_input_dim:
        crop_rep = self.project_crop_dim(crop_rep)
    assert crop_rep.shape[0] == B
    reps.append(crop_rep)  # either k_dim or project_hdim
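    # Full-map view: embed every glyph on the whole H x W map and run the
    # main representation network over it.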
    # -- [B x H x W x K]
    glyphs_emb = self.glyph_embedding(glyphs)
    # -- [B x K x W x H]: channels-first for the conv stack
    glyphs_emb = glyphs_emb.transpose(1, 3)
    glyphs_rep = self.extract_representation(glyphs_emb)
    # -- [B x K']
    glyphs_rep = glyphs_rep.view(B, -1)
    if self.equalize_input_dim:
        glyphs_rep = self.project_glyph_dim(glyphs_rep)
    assert glyphs_rep.shape[0] == B
    reps.append(glyphs_rep)
    # MESSAGING MODEL
    if self.msg_model != "none":
        messages = inputs["message"].long()
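        # Three encoder variants share the same downstream interface: "cnn"
        # runs a character CNN over one-hot characters, "lt_cnn" runs it over
        # learned character embeddings, and anything else is an LSTM/GRU.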
        if self.msg_model == "cnn":
            # convert messages to one-hot: [B x NUM_CHARS x 256]
            one_hot = F.one_hot(messages, num_classes=NUM_CHARS).transpose(
                1, 2
            )
            char_rep = self.conv2_6_fc(self.conv1(one_hot.float()))
        elif self.msg_model == "lt_cnn":
            # -- [B x E x 256]
            char_emb = self.char_lt(messages).transpose(1, 2)
            char_rep = self.conv2_6_fc(self.conv1(char_emb))
        else:  # lstm, gru
            char_emb = self.char_lt(messages)
            output = self.char_rnn(char_emb)[0]
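            # Bidirectional RNN trick: the first h_dim // 2 channels of the
            # last timestep hold the forward summary, while the second half of
            # the first timestep holds the backward summary.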
            fwd_rep = output[:, -1, : self.h_dim // 2]
            bwd_rep = output[:, 0, self.h_dim // 2 :]
            char_rep = torch.cat([fwd_rep, bwd_rep], dim=1)
        if self.equalize_input_dim:
            char_rep = self.project_msg_dim(char_rep)
        reps.append(char_rep)
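    # Fuse all modality representations into a single state vector for the
    # policy and value heads.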
    st = torch.cat(reps, dim=1)
    # -- [B x K]
    st = self.fc(st)
    return st
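    # Shape sketch (illustrative; assumes NLE's 21 x 79 map, a 9 x 9 crop,
    # and a 256-character message buffer):
    #   inputs["glyphs"]:  [B x 21 x 79] integer glyph ids
    #   inputs["message"]: [B x 256] integer character codes
    #   returned st:       [B x K] fused state representation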