in monobeast/minigrid/monobeast_amigo.py [0:0]
def forward(self, inputs, core_state=(), goal=[]):
    """Main function: takes an observation and a goal and returns an action."""
    # -- [unroll_length x batch_size x height x width x channels]
    x = inputs["frame"]
    T, B, h, w, *_ = x.shape
    # -- [unroll_length * batch_size x height x width x channels]
    x = torch.flatten(x, 0, 1)  # Merge time and batch.
    goal = torch.flatten(goal, 0, 1)
    # Build a one-hot goal channel: mark each sample's goal cell in the
    # flattened h*w grid, then reshape back to the image layout.
    goal_channel = torch.zeros_like(x, requires_grad=False)
    goal_channel = torch.flatten(goal_channel, 1, 2)[:, :, 0]
    for i in range(goal.shape[0]):
        goal_channel[i, goal[i]] = 1.0
    goal_channel = goal_channel.view(T * B, h, w, 1)
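    # The loop above could be vectorized in a single call (a sketch, assuming
    # `goal` holds int64 cell indices), applied before the reshape:
    #     goal_channel.scatter_(1, goal.view(-1, 1), 1.0)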
    carried_col = inputs["carried_col"]
    carried_obj = inputs["carried_obj"]
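    # Cast observations to float when embeddings are disabled; otherwise keep
    # them as long indices into the embedding tables.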
    if flags.disable_use_embedding:
        x = x.float()
        goal = goal.float()
        carried_obj = carried_obj.float()
        carried_col = carried_col.float()
    else:
        x = x.long()
        goal = goal.long()
        carried_obj = carried_obj.long()
        carried_col = carried_col.long()
    # If the generator is disabled, zero the goal channel so the policy
    # conditions on no goal.
    if flags.no_generator:
        goal_channel = torch.zeros_like(goal_channel)
    # -- [T*B x H x W x K]
    x = torch.cat([self.create_embeddings(x, 0), self.create_embeddings(x, 1),
                   self.create_embeddings(x, 2), goal_channel.float()], dim=3)
    carried_obj_emb = self._select(self.embed_object, carried_obj)
    carried_col_emb = self._select(self.embed_color, carried_col)
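    # Convert [T*B, h, w, K] to channels-first for the convolutional
    # feature extractor.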
    x = x.transpose(1, 3)
    x = self.feat_extract(x)
    x = x.view(T * B, -1)
    carried_obj_emb = carried_obj_emb.view(T * B, -1)
    carried_col_emb = carried_col_emb.view(T * B, -1)
    union = torch.cat([x, carried_obj_emb, carried_col_emb], dim=1)
    core_input = self.fc(union)
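    # Unroll the recurrent core one step at a time so the hidden state can be
    # reset (multiplied by the not-done mask) wherever an episode ended.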
    if self.use_lstm:
        core_input = core_input.view(T, B, -1)
        core_output_list = []
        notdone = (~inputs["done"]).float()
        for inp, nd in zip(core_input.unbind(), notdone.unbind()):
            nd = nd.view(1, -1, 1)
            core_state = tuple(nd * s for s in core_state)
            output, core_state = self.core(inp.unsqueeze(0), core_state)
            core_output_list.append(output)
        core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
    else:
        core_output = core_input
        core_state = tuple()
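    # Policy and baseline heads: sample actions during training, act greedily
    # at evaluation time.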
    policy_logits = self.policy(core_output)
    baseline = self.baseline(core_output)
    if self.training:
        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
    else:
        action = torch.argmax(policy_logits, dim=1)
    policy_logits = policy_logits.view(T, B, self.num_actions)
    baseline = baseline.view(T, B)
    action = action.view(T, B)
    return dict(policy_logits=policy_logits, baseline=baseline, action=action), core_state
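
# Usage sketch (hypothetical names; assumes an AMIGo-style batch dict with
# "frame", "carried_obj", "carried_col", and "done" entries, plus a goal
# tensor of cell indices):
#     outputs, core_state = net(batch, core_state, goal=goal)
#     action = outputs["action"]      # [T, B] sampled or greedy actions
#     baseline = outputs["baseline"]  # [T, B] value estimates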