in monobeast/minigrid/monobeast_amigo.py [0:0]
def _select(self, embed, x):
    """Gather embedding rows for the integer indices in ``x``.

    Args:
        embed: embedding module exposing a ``weight`` matrix and callable on
            index tensors (presumably a ``torch.nn.Embedding`` -- confirm at
            call sites).
        x: integer index tensor of any shape.

    Returns:
        A tensor of shape ``x.shape + (embedding_dim,)`` holding the selected
        rows. When ``self.use_index_select`` is falsy, the result comes
        straight from ``embed(x)``.

    NOTE(review): the ``index_select`` path reads ``embed.weight`` directly
    and never calls the module's forward. For an ``nn.Embedding``, this means
    forward-only behaviour such as ``max_norm`` renormalisation is skipped.
    """
    if not self.use_index_select:
        return embed(x)

    # A single index_select over a flat index vector gathers every row at
    # once. Afterwards, restore x's dimensions and append the embedding axis.
    flat_idx = x.reshape(-1)
    rows = embed.weight.index_select(0, flat_idx)
    target_shape = tuple(x.shape) + (-1,)
    return rows.reshape(target_shape)