def forward()

in monobeast/minigrid/monobeast_amigo.py [0:0]


    def forward(self, inputs):
        """Main Function, takes an observation and returns a goal."""
        x = inputs["frame"]
        T, B, *_ = x.shape
        carried_col = inputs["carried_col"]
        carried_obj = inputs["carried_obj"]


        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        if flags.disable_use_embedding:
            # Use the raw integer observations directly as float features.
            x = x.float()
            carried_obj = carried_obj.float()
            carried_col = carried_col.float()
        else:
            x = x.long()
            carried_obj = carried_obj.long()
            carried_col = carried_col.long()
            # Embed the three observation channels (object, color, state) and
            # concatenate the embeddings along the last dimension.
            x = torch.cat(
                [self.create_embeddings(x, 0), self.create_embeddings(x, 1), self.create_embeddings(x, 2)],
                dim=3,
            )
            # The carried-item embeddings only exist in this branch, so reshape
            # them here; they are not used further in this forward pass.
            carried_obj_emb = self._select(self.embed_object, carried_obj)
            carried_col_emb = self._select(self.embed_color, carried_col)
            carried_obj_emb = carried_obj_emb.view(T * B, -1)
            carried_col_emb = carried_col_emb.view(T * B, -1)

        x = x.transpose(1, 3)  # [T*B, H, W, K] -> [T*B, K, W, H] for the conv stack.

        # Convolutional feature extractor; the flattened output is used as the
        # logits over candidate goal locations.
        x = self.extract_representation(x)
        x = x.view(T * B, -1)

        generator_logits = x.view(T * B, -1)

        # Value estimate for the teacher (goal-generator) policy.
        generator_baseline = self.baseline_teacher(generator_logits)
        
        # Sample one goal location from the categorical distribution over cells.
        goal = torch.multinomial(F.softmax(generator_logits, dim=1), num_samples=1)

        # Restore the time and batch dimensions.
        generator_logits = generator_logits.view(T, B, -1)
        generator_baseline = generator_baseline.view(T, B)
        goal = goal.view(T, B)

        if flags.inner:
            # When goals are restricted to the inner grid, convert the sampled
            # index accordingly.
            goal = self.convert_inner(goal)

        return dict(goal=goal, generator_logits=generator_logits, generator_baseline=generator_baseline)
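
For context, here is a minimal sketch of how this forward pass could be exercised with dummy inputs. The 5x5 grid size, the value ranges, the `generator` instance, and the `flags` settings are assumptions for illustration, not taken from the repository.

    # Hypothetical smoke test (assumes a constructed Generator instance named
    # `generator`, module-level `flags` with disable_use_embedding=False and
    # inner=False, and MiniGrid-style observations with 3 channels per cell).
    import torch

    T, B, H, W = 4, 2, 5, 5
    inputs = {
        "frame": torch.randint(0, 10, (T, B, H, W, 3)),  # encoded grid cells
        "carried_obj": torch.randint(0, 10, (T, B)),     # carried object index
        "carried_col": torch.randint(0, 6, (T, B)),      # carried color index
    }

    out = generator(inputs)
    assert out["goal"].shape == (T, B)                   # one goal cell per timestep
    assert out["generator_logits"].shape[:2] == (T, B)   # logits over candidate cells
    assert out["generator_baseline"].shape == (T, B)     # teacher baseline values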