mae_envs/wrappers/line_of_sight.py (67 lines of code) (raw):
import gym
import numpy as np
from mae_envs.util.vision import insight, in_cone2d
from mae_envs.wrappers.util import update_obs_space
class AgentAgentObsMask2D(gym.ObservationWrapper):
    """ Adds a 'mask_aa_obs' observation stating which agents each agent can see.

    Rows of the (n_agents, n_agents) mask index the observing agent and columns
    index the observed agent. An entry starts from the `in_cone2d` result. Every
    off-diagonal cone hit is then replaced by the `insight` line-of-sight check
    between the two agents' geoms. Diagonal entries keep the cone result.

    Args:
        cone_angle: (float) angle in radians between the axis and the edge of
            the observation cone
    """
    def __init__(self, env, cone_angle=3/8 * np.pi):
        super().__init__(env)
        self.cone_angle = cone_angle
        self.n_agents = self.unwrapped.n_agents
        mask_shape = (self.n_agents, self.n_agents)
        self.observation_space = update_obs_space(env, {'mask_aa_obs': mask_shape})

    def observation(self, obs):
        """Write the agent-agent visibility mask into `obs` (in place) and return it."""
        # All but the last coordinate of each agent position are used as its 2D position.
        planar_pos = obs['agent_pos'][:, :-1]
        headings = np.squeeze(obs['agent_angle'], -1)
        visible = in_cone2d(planar_pos, headings, self.cone_angle, planar_pos)

        # Being inside the cone is not enough: each off-diagonal hit must also be
        # in line of sight. Pairs are collected before any entry is overwritten.
        candidates = [(viewer, target)
                      for viewer, target in np.argwhere(visible)
                      if viewer != target]
        if candidates:
            sim = self.unwrapped.sim
            geom_idxs = self.metadata['agent_geom_idxs']
            for viewer, target in candidates:
                visible[viewer, target] = insight(sim, geom_idxs[viewer], geom_idxs[target])

        obs['mask_aa_obs'] = visible
        return obs
class AgentSiteObsMask2D(gym.ObservationWrapper):
    """ Adds a mask observation stating which sites each agent can see.

    The mask has shape (n_agents, n_sites): rows index agents, columns index
    sites. An entry starts from the `in_cone2d` result, and every cone hit is
    then replaced by the `insight` line-of-sight check from the agent's geom
    to the site's position.

    Args:
        pos_obs_key: (string) key of the site position observation, of shape (n_sites, 3)
        mask_obs_key: (string) key under which the output mask observation is stored
        cone_angle: (float) angle in radians between the axis and the edge of
            the observation cone
    """
    def __init__(self, env, pos_obs_key, mask_obs_key, cone_angle=3/8 * np.pi):
        super().__init__(env)
        self.cone_angle = cone_angle
        self.n_agents = self.unwrapped.n_agents
        self.pos_obs_key = pos_obs_key
        self.mask_obs_key = mask_obs_key
        # Read the wrapped env's spaces before this wrapper replaces them.
        inner_spaces = self.observation_space.spaces
        assert self.n_agents == inner_spaces['agent_pos'].shape[0]
        self.n_objects = inner_spaces[pos_obs_key].shape[0]
        self.observation_space = update_obs_space(
            env, {mask_obs_key: (self.n_agents, self.n_objects)})

    def observation(self, obs):
        """Write the agent-site visibility mask into `obs` (in place) and return it."""
        planar_agents = obs['agent_pos'][:, :-1]
        headings = np.squeeze(obs['agent_angle'], -1)
        site_pos = obs[self.pos_obs_key]
        visible = in_cone2d(planar_agents, headings, self.cone_angle, site_pos[:, :2])

        # Each cone hit must additionally be in line of sight of the agent's geom.
        hits = np.argwhere(visible)
        if hits.size:
            geom_idxs = self.metadata['agent_geom_idxs']
            for agent, site in hits:
                visible[agent, site] = insight(self.unwrapped.sim,
                                               geom_idxs[agent],
                                               pt2=site_pos[site])

        obs[self.mask_obs_key] = visible
        return obs
class AgentGeomObsMask2D(gym.ObservationWrapper):
    """ Adds a mask observation stating which geoms each agent can see.

    The mask has shape (n_agents, n_geoms): rows index agents, columns index
    geoms. An entry starts from the `in_cone2d` result. Each cone hit is then
    handled in one of two ways:
      * forced to 0 when the geom's index is -1;
      * otherwise replaced by the `insight` line-of-sight check between the
        agent's geom and that geom.

    Args:
        pos_obs_key: (string) key of the geom position observation, of shape (n_geoms, 3)
        geom_idxs_obs_key: (string) key of an observation that gives, for each
            object to be masked, its Mujoco geom index (e.g. into sim.geom_names)
            as an array of shape (n_geoms, 1); -1 marks a geom that does not exist
        mask_obs_key: (string) key under which the output mask observation is stored
        cone_angle: (float) angle in radians between the axis and the edge of
            the observation cone
    """
    def __init__(self, env, pos_obs_key, geom_idxs_obs_key, mask_obs_key, cone_angle=3/8 * np.pi):
        super().__init__(env)
        self.cone_angle = cone_angle
        self.n_agents = self.unwrapped.n_agents
        self.pos_obs_key = pos_obs_key
        self.mask_obs_key = mask_obs_key
        self.geom_idxs_obs_key = geom_idxs_obs_key
        # Read the wrapped env's spaces before this wrapper replaces them.
        inner_spaces = self.observation_space.spaces
        assert self.n_agents == inner_spaces['agent_pos'].shape[0]
        self.n_objects = inner_spaces[pos_obs_key].shape[0]
        self.observation_space = update_obs_space(
            env, {mask_obs_key: (self.n_agents, self.n_objects)})

    def observation(self, obs):
        """Write the agent-geom visibility mask into `obs` (in place) and return it."""
        planar_agents = obs['agent_pos'][:, :-1]
        headings = np.squeeze(obs['agent_angle'], -1)
        planar_geoms = obs[self.pos_obs_key][:, :2]
        visible = in_cone2d(planar_agents, headings, self.cone_angle, planar_geoms)

        hits = np.argwhere(visible)
        if hits.size:
            agent_geoms = self.metadata['agent_geom_idxs']
            geom_ids = obs[self.geom_idxs_obs_key][:, 0]
            for agent, obj in hits:
                viewer_geom = agent_geoms[agent]
                target_geom = geom_ids[obj]
                if target_geom != -1:
                    visible[agent, obj] = insight(self.unwrapped.sim, viewer_geom,
                                                  geom2_id=target_geom)
                else:
                    # The number of geoms may vary between episodes; geoms that
                    # do not exist are expected to carry index -1 and are
                    # never visible.
                    visible[agent, obj] = 0

        obs[self.mask_obs_key] = visible
        return obs