in monobeast/minigrid/monobeast_amigo.py [0:0]
def reached_goal_func(frames, goals, initial_frames=None, done_aux=None, *,
                      modify=None, no_boundary_awareness=None):
    """Auxiliary function which evaluates whether agent has reached the goal.

    Args:
        frames: observation tensor. The indexing below assumes a 5-D layout
            (T, B, H, W, C) whose grid dims 2 and 3 are flattened into one
            cell index (h * W + w) -- TODO confirm against the callers.
        goals: integer-valued tensor of flattened cell indices; assumed to
            be (T, B), i.e. ``frames.shape[:2]``.
        initial_frames: reference frames with the same shape as ``frames``.
            Required when ``modify`` is true.
        done_aux: episode-termination tensor broadcastable to ``goals``
            (presumably 0/1 or bool). Required when both ``modify`` and
            ``no_boundary_awareness`` are true.
        modify: if true, a goal counts as reached when any channel of the
            goal cell differs from ``initial_frames``; otherwise it counts as
            reached when the cell whose channel 0 equals 10 (the agent's
            object id, per the original comment) is the goal cell.
            Defaults to the global ``flags.modify``.
        no_boundary_awareness: in ``modify`` mode, multiply the result by
            ``1 - done_aux`` so goals reached on terminal steps count as 0.
            Defaults to the global ``flags.no_boundary_awareness``.

    Returns:
        Float tensor with shape ``goals.shape`` (1.0 = reached).

    Raises:
        ValueError: if a tensor required by the selected mode is missing, or
            if the agent is not found exactly once per (T, B) frame.
    """
    if modify is None:
        modify = flags.modify

    if modify:
        if initial_frames is None:
            raise ValueError("initial_frames is required when modify is enabled")
        new_frame = torch.flatten(frames, 2, 3)
        old_frame = torch.flatten(initial_frames, 2, 3)
        # A cell is "modified" when any of its channels differs from the
        # initial frame (equivalent to the old `sum(equal) != 3` for 3 channels).
        changed = ~torch.all(new_frame == old_frame, dim=3)
        # Squeeze only the gathered dim so singleton T or B dims survive and
        # the result always has goals.shape.
        reached = torch.gather(changed, 2, goals.long().unsqueeze(2)).squeeze(2).float()

        if no_boundary_awareness is None:
            no_boundary_awareness = flags.no_boundary_awareness
        if no_boundary_awareness:
            if done_aux is None:
                raise ValueError(
                    "done_aux is required when no_boundary_awareness is enabled")
            # Do not credit goals reached on steps flagged by done_aux.
            reached = reached * (1 - done_aux.float())
        return reached

    # Locate the agent: channel 0 of each flattened cell holds the object id.
    agent_mask = torch.flatten(frames, 2, 3)[:, :, :, 0] == 10
    # nonzero() returns indices in row-major order, matching goals' layout.
    agent_location = agent_mask.nonzero()
    if agent_location.shape[0] != goals.numel():
        raise ValueError(
            "expected exactly one agent cell per frame, found "
            f"{agent_location.shape[0]} for {goals.numel()} goals")
    agent_location = agent_location[:, 2].view(goals.shape)
    return (goals == agent_location).float()