in monobeast/minigrid/monobeast_amigo.py [0:0]
def create_embeddings(self, x, id):
    """Generates compositional embeddings.

    Takes every third entry of the last axis of ``x`` starting at offset
    ``id`` and looks those indices up, via ``self._select``, in the
    embedding table selected by ``id``.

    Args:
        x: index tensor of rank >= 4 (it is sliced as ``x[:, :, :, id::3]``).
            Presumably its last axis interleaves (object, color, contains)
            channel indices -- TODO confirm against the caller.
        id: channel selector (name kept for backward compatibility with
            keyword callers, although it shadows the builtin):
            0 -> ``self.embed_object``, 1 -> ``self.embed_color``,
            2 -> ``self.embed_contains``.

    Returns:
        The result of ``self._select`` with its axes 3 and 4 merged by
        ``torch.flatten``.

    Raises:
        ValueError: if ``id`` is not 0, 1 or 2.
    """
    if id == 0:
        embed = self.embed_object
    elif id == 1:
        embed = self.embed_color
    elif id == 2:
        embed = self.embed_contains
    else:
        # Without this branch an unknown id left the embedding unassigned and
        # surfaced as a confusing UnboundLocalError.
        raise ValueError(
            f"id must be 0 (object), 1 (color) or 2 (contains), got {id!r}"
        )
    # `id::3` is a slice, so axis 3 is kept rather than dropped; the trailing
    # axis added by `_select` (presumably the embedding dim -- verify) is then
    # merged into it by flatten(3, 4).
    objects_emb = self._select(embed, x[:, :, :, id::3])
    return torch.flatten(objects_emb, 3, 4)