def reached_goal_func()

in monobeast/minigrid/monobeast_amigo.py [0:0]


def reached_goal_func(frames, goals, initial_frames=None, done_aux=None, *,
                      modify=None, no_boundary_awareness=None):
    """Auxiliary function which evaluates whether agent has reached the goal.

    ``frames`` is a rank-5 tensor whose dims 2 and 3 are flattened into a
    single cell axis; ``goals`` holds, for every entry of the two leading
    dims, a flat index into that cell axis (cell ``(h, w)`` -> ``h * W + w``).
    NOTE(review): the leading dims are presumably (time, batch) and the last
    dim the observation channels — confirm against callers.

    Two goal definitions are supported:

    * modify mode: the goal is reached when the goal cell of ``frames``
      differs in at least one channel from the same cell of
      ``initial_frames``. If ``no_boundary_awareness`` is set, the result is
      additionally zeroed wherever ``done_aux`` is true.
    * default mode: the goal is reached when the (single) cell whose channel
      0 equals 10 — presumably the agent's object id — is the goal cell.

    Args:
        frames: rank-5 observation tensor.
        goals: rank-2 tensor of flat goal-cell indices (modify mode gathers
            with ``goals.long()``).
        initial_frames: reference frames, same shape as ``frames``; required
            in modify mode.
        done_aux: per-entry done flags broadcastable to ``goals.shape``;
            required in modify mode when ``no_boundary_awareness`` is set.
        modify: overrides the global ``flags.modify`` when not None.
        no_boundary_awareness: overrides the global
            ``flags.no_boundary_awareness`` when not None.

    Returns:
        Tensor with the leading two dims of ``frames`` (default mode:
        ``goals.shape``). Modify mode returns bool, or float when boundary
        awareness masking is applied; default mode returns float.

    Raises:
        ValueError: if a tensor required by the selected mode is missing.
        RuntimeError: if, in default mode, some frame does not contain
            exactly one agent cell (or shapes do not match ``goals``).
    """
    if modify is None:
        modify = flags.modify

    if modify:
        if initial_frames is None:
            raise ValueError("initial_frames is required when modify is enabled")
        if no_boundary_awareness is None:
            no_boundary_awareness = flags.no_boundary_awareness

        new_cells = torch.flatten(frames, 2, 3)
        old_cells = torch.flatten(initial_frames, 2, 3)
        # A cell counts as changed if any channel differs. For 3 channels this
        # equals the former "sum of equal channels != 3" test, and it stays
        # correct for any other channel count.
        changed = (new_cells != old_cells).any(dim=3)
        # Pick the goal cell of every frame. Squeeze only the gathered axis so
        # size-1 leading dims are kept (a bare squeeze() collapsed them).
        reached = torch.gather(changed, 2, goals.long().unsqueeze(2)).squeeze(2)
        if no_boundary_awareness:
            if done_aux is None:
                raise ValueError(
                    "done_aux is required when no_boundary_awareness is enabled")
            reached = reached.float() * (1 - done_aux.float())
        return reached

    # Channel 0 of every cell holds the object id; 10 is the id matched as the agent.
    object_ids = torch.flatten(frames, 2, 3)[:, :, :, 0]
    is_agent = object_ids == 10
    # nonzero() lists hits in row-major order, so positions line up with
    # goals only if every frame has exactly one agent cell; a mere total-count
    # match would misalign silently.
    if not bool((is_agent.sum(dim=2) == 1).all()):
        raise RuntimeError(
            "expected exactly one agent cell (object id 10) per frame")
    agent_location = is_agent.nonzero()[:, 2].view(goals.shape)
    return (goals == agent_location).float()