in minihack/agent/polybeast/models/intrinsic.py [0:0]
def forward(self, inputs, core_state, learning=False):
    if not learning:
        # No need to compute RIDE outputs when not in a learning step.
        return super(RIDENet, self).forward(inputs, core_state, learning)
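    # T = unroll length, B = batch size: the leading dims of every input tensor.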
    T, B, *_ = inputs["glyphs"].shape
    glyphs, features = self.prepare_input(inputs)
    # -- [B x 2] x,y coordinates
    coordinates = features[:, :2]
    features = features.view(T * B, -1).float()
    # -- [B x K]
    features_emb = self.embed_features(features)
    assert features_emb.shape[0] == T * B
    reps = [features_emb]
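    # Crop an egocentric H' x W' window around the agent's (x, y)
    # coordinates from each glyph layer.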
    # -- [B x H' x W']
    crop = self.glyph_embedding.GlyphTuple(
        *[self.crop(g, coordinates) for g in glyphs]
    )
    # -- [B x H' x W' x K]
    crop_emb = self.glyph_embedding(crop)
    if self.crop_model == "transformer":
        # -- [B x W' x H' x K]
        crop_rep = self.extract_crop_representation(crop_emb, mask=None)
    elif self.crop_model == "cnn":
        # -- [B x K x W' x H']
        crop_emb = crop_emb.transpose(1, 3)
        # -- [B x W' x H' x K]
        crop_rep = self.extract_crop_representation(crop_emb)
    # -- [B x K']
    crop_rep = crop_rep.view(T * B, -1)
    assert crop_rep.shape[0] == T * B
    reps.append(crop_rep)
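    # Encode the full (uncropped) glyph map with the main representation network.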
    # -- [B x H x W x K]
    glyphs_emb = self.glyph_embedding(glyphs)
    # glyphs_emb = self.embed(glyphs)
    # -- [B x K x W x H]
    glyphs_emb = glyphs_emb.transpose(1, 3)
    # -- [B x W x H x K]
    glyphs_rep = self.extract_representation(glyphs_emb)
    # -- [B x K']
    glyphs_rep = glyphs_rep.view(T * B, -1)
    assert glyphs_rep.shape[0] == T * B
    # -- [B x K'']
    reps.append(glyphs_rep)
    st = torch.cat(reps, dim=1)
    # -- [B x K]
    st = self.fc(st)

    # PREDICTOR NETWORK
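    # Compute the RIDE state embedding ride_st used for the intrinsic
    # reward, from the cropped view, the full glyph map, or both.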
    if self.intrinsic_input == "crop_only":
        ride_crop_emb = self.ride_embed(crop).transpose(1, 3)
        ride_crop_rep = self.ride_extract_crop_representation(ride_crop_emb)
        ride_st = self.ride_fc(ride_crop_rep.view(T * B, -1))
    elif self.intrinsic_input == "glyph_only":
        ride_glyphs_emb = self.ride_embed(glyphs).transpose(1, 3)
        ride_glyphs_rep = self.ride_extract_representation(ride_glyphs_emb)
        ride_st = self.ride_fc(ride_glyphs_rep.view(T * B, -1))
    else:  # full
        ride_reps = []
        ride_feats = self.ride_embed_features(features)
        ride_reps.append(ride_feats)
        ride_crop_emb = self.ride_embed(crop).transpose(1, 3)
        ride_crop_rep = self.ride_extract_crop_representation(ride_crop_emb)
        ride_reps.append(ride_crop_rep.view(T * B, -1))
        ride_glyphs_emb = self.ride_embed(glyphs).transpose(1, 3)
        ride_glyphs_rep = self.ride_extract_representation(ride_glyphs_emb)
        ride_reps.append(ride_glyphs_rep.view(T * B, -1))
        ride_st = self.ride_fc(torch.cat(ride_reps, dim=1))
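    # Step the LSTM core over the unroll one step at a time so the hidden
    # state can be reset at episode boundaries.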
    if self.use_lstm:
        core_input = st.view(T, B, -1)
        core_output_list = []
        notdone = (~inputs["done"]).float()
        for input, nd in zip(core_input.unbind(), notdone.unbind()):
            # Reset core state to zero whenever an episode ended.
            # Make `done` broadcastable with (num_layers, B, hidden_size)
            # states:
            nd = nd.view(1, -1, 1)
            core_state = tuple(nd * t for t in core_state)
            output, core_state = self.core(input.unsqueeze(0), core_state)
            core_output_list.append(output)
        core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
    else:
        core_output = st
    # -- [B x A]
    policy_logits = self.policy(core_output)
    # -- [B x 1]
    baseline = self.baseline(core_output)
    if self.training:
        action = torch.multinomial(
            F.softmax(policy_logits, dim=1), num_samples=1
        )
    else:
        # Don't sample when testing.
        action = torch.argmax(policy_logits, dim=1)

    policy_logits = policy_logits.view(T, B, self.num_actions)
    baseline = baseline.view(T, B)
    action = action.view(T, B)
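    # RIDE-specific outputs: the state embedding for the intrinsic reward
    # and a separate baseline head for the intrinsic value stream.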
    output = dict(
        policy_logits=policy_logits,
        baseline=baseline,
        action=action,
        state_embedding=ride_st.view(T, B, -1),
        int_baseline=self.int_baseline(core_output).view(T, B),
    )
    return (output, core_state)