in monobeast/minigrid/monobeast_amigo.py [0:0]
def forward(self, inputs):
    """Take a batch of observations and sample a goal for each (time, batch) entry.

    Args:
        inputs: dict with key "frame" of shape (T, B, H, W, C) — assumed; the
            code only requires the first two dims to be (T, B) — plus
            "carried_col" and "carried_obj" tensors (currently unused, see
            NOTE below).

    Returns:
        dict with:
            goal: (T, B) tensor of sampled goal indices.
            generator_logits: (T, B, num_goals) unnormalized goal scores.
            generator_baseline: (T, B) value baseline from the teacher head.
    """
    x = inputs["frame"]
    T, B, *_ = x.shape
    x = torch.flatten(x, 0, 1)  # Merge time and batch dims: (T*B, ...).
    if flags.disable_use_embedding:
        # Feed the raw observation as floats; skip the learned embeddings.
        x = x.float()
    else:
        x = x.long()
        # Embed each of the three observation channels and concatenate the
        # embeddings along the last (feature) axis.
        x = torch.cat(
            [
                self.create_embeddings(x, 0),
                self.create_embeddings(x, 1),
                self.create_embeddings(x, 2),
            ],
            dim=3,
        )
    # NOTE(review): the original code also embedded inputs["carried_obj"] and
    # inputs["carried_col"] here, but those embeddings were never used by any
    # later statement; the dead computation has been removed. Restore it if
    # the carried-item features are meant to be concatenated into x.
    x = x.transpose(1, 3)  # Channels-first layout for the conv extractor.
    x = self.extract_representation(x)
    # One (T*B, num_goals) view suffices; the original reshaped twice.
    generator_logits = x.view(T * B, -1)
    generator_baseline = self.baseline_teacher(generator_logits)
    # Sample one goal per row from the softmax distribution over goals.
    goal = torch.multinomial(F.softmax(generator_logits, dim=1), num_samples=1)
    generator_logits = generator_logits.view(T, B, -1)
    generator_baseline = generator_baseline.view(T, B)
    goal = goal.view(T, B)
    if flags.inner:
        # Map raw goal indices into the inner-grid coordinate space.
        goal = self.convert_inner(goal)
    return dict(goal=goal, generator_logits=generator_logits, generator_baseline=generator_baseline)