mae_envs/wrappers/food.py (115 lines of code) (raw):

import gym
import numpy as np
from mae_envs.wrappers.util import update_obs_space
from mujoco_worldgen.util.types import store_args
from gym.spaces import Tuple, MultiDiscrete


class FoodHealthWrapper(gym.Wrapper):
    '''
        Adds food health to underlying env.
        Manages food levels.
        Args:
            eat_thresh (float): radius within which food items can be eaten
            max_food_health (int): number of times a food item can be eaten
                before it disappears
            respawn_time (int): number of time steps after which food items
                that have been eaten reappear
            food_rew_type (string): can be
                'selfish': each agent gets an individual reward for the food they eat
                'joint_mean': food rewards are averaged over teams
            reward_scale (float or (float, float)): scales the reward by this amount. If
                tuple of floats, the exact reward scaling is uniformly sampled from
                (reward_scale[0], reward_scale[1]) at the beginning of every episode.
            reward_scale_obs (bool): if true, adds the reward scale for the current
                episode to food_obs
            split_eat_between_agents (bool): if true, when several agents eat the same
                food item in one step, each agent's bite is divided by the number of
                agents eating that item, so the item loses at most 1 health per step.

        Reads metadata keys 'n_agents', 'max_n_food', 'curr_n_food', 'food_size' and
        (for 'joint_mean') 'team_index'; reads observation keys 'food_pos' and
        'agent_pos'. Adds the action 'action_eat_food' and the observations
        'food_obs', 'food_health' and 'food_eat'.
    '''
    @store_args
    def __init__(self, env, eat_thresh=0.5, max_food_health=10, respawn_time=np.inf,
                 food_rew_type='selfish', reward_scale=1.0, reward_scale_obs=False,
                 split_eat_between_agents=False):
        # NOTE: this class reads self.eat_thresh, self.respawn_time, etc. without
        # assigning them, so it relies on @store_args to expose the constructor
        # arguments as attributes.
        super().__init__(env)
        self.n_agents = self.metadata['n_agents']

        # A scalar scale becomes a degenerate (low, high) range so reset() can always
        # sample uniformly from self.reward_scale.
        if type(reward_scale) not in [list, tuple, np.ndarray]:
            self.reward_scale = [reward_scale, reward_scale]

        # Reset obs/action space to match
        self.max_n_food = self.metadata['max_n_food']
        self.curr_n_food = self.metadata['curr_n_food']
        self.max_food_size = self.metadata['food_size']
        # food_obs = food_pos columns + health (+ reward scale). The value 4 presumably
        # means food_pos has 3 columns — confirm against the underlying env.
        food_dim = 5 if self.reward_scale_obs else 4
        self.observation_space = update_obs_space(self.env, {
            'food_obs': (self.max_n_food, food_dim),
            'food_health': (self.max_n_food, 1),
            'food_eat': (self.max_n_food, 1),
        })
        # One binary "eat" choice per (agent, food item).
        self.action_space.spaces['action_eat_food'] = Tuple(
            [MultiDiscrete([2] * self.max_n_food) for _ in range(self.n_agents)])

    def reset(self):
        '''Reset the underlying env and restore every food item to full health.'''
        obs = self.env.reset()
        sim = self.unwrapped.sim

        # Reset obs/action space to match
        self.curr_n_food = self.metadata['curr_n_food']

        # Explicit integer dtype so an episode without food still yields a valid
        # (empty) index array instead of an empty float64 array.
        self.food_site_ids = np.array(
            [sim.model.site_name2id(f'food{i}') for i in range(self.curr_n_food)],
            dtype=int)
        # Reset food healths
        self.food_healths = np.ones((self.curr_n_food, 1)) * self.max_food_health
        self.eat_per_food = np.zeros((self.curr_n_food, 1))

        # Per-item count of consecutive steps spent at zero health.
        self.respawn_counters = np.zeros((self.curr_n_food,))

        self.curr_reward_scale = np.random.uniform(self.reward_scale[0], self.reward_scale[1])

        return self.observation(obs)

    def observation(self, obs):
        '''Add food healths (and optionally the reward scale) to the observation dict.'''
        # Add food position and healths to observations
        food_pos = obs['food_pos']
        obs['food_health'] = self.food_healths
        obs['food_obs'] = np.concatenate([food_pos, self.food_healths], 1)
        if self.reward_scale_obs:
            scale_col = np.ones((self.curr_n_food, 1)) * self.curr_reward_scale
            obs['food_obs'] = np.concatenate([obs['food_obs'], scale_col], 1)
        # Total amount eaten from each item during the last step, shape (curr_n_food, 1).
        obs['food_eat'] = self.eat_per_food
        return obs

    def step(self, action):
        '''
            Step the env, let agents eat nearby food, update food healths/sizes and
            respawn timers, and add the (scaled) food reward to `rew`.

            Sets info['agents_eat'] to the (n_agents, curr_n_food) float32 array of
            how much each agent ate from each item this step.
        '''
        action_eat_food = action.pop('action_eat_food')
        obs, rew, done, info = self.env.step(action)

        if self.curr_n_food > 0:
            sim = self.unwrapped.sim
            # Agent i eats food j if it is within eat_thresh of it, food j still has
            # health left, and agent i chose to eat food j.
            dist_to_food = np.linalg.norm(
                obs['agent_pos'][:, None] - obs['food_pos'][None], axis=-1)
            eat = np.logical_and(dist_to_food < self.eat_thresh, self.food_healths.T > 0)
            eat = np.logical_and(eat, action_eat_food).astype(np.float32)
            if self.split_eat_between_agents:
                eat_per_food = np.sum(eat, 0)
                eat[:, eat_per_food > 0] /= eat_per_food[eat_per_food > 0]
            eat_per_food = np.sum(eat, 0)

            # Make sure that all agents can't have the last bite of food.
            # At that point, food is split evenly
            over_eat = self.food_healths[:, 0] < eat_per_food
            eat[:, over_eat] *= (self.food_healths[over_eat, 0] / eat_per_food[over_eat])
            eat_per_food = np.sum(eat, 0)
            self.eat_per_food = eat_per_food[:, None]

            # Update food healths
            self.food_healths -= eat_per_food[:, None]
            # Over-eaten items are now fully consumed. `eat` is float32 while the
            # healths are float64, so the subtraction can leave a tiny negative or
            # positive residue; snap to exactly 0 so the assert below does not fire
            # spuriously and the `== 0` respawn counter below starts counting.
            self.food_healths[over_eat] = 0.0

            # Shrink each food site in proportion to the health it lost.
            size_diff = eat_per_food[:, None] * (self.max_food_size / self.max_food_health)
            size = sim.model.site_size[self.food_site_ids] - size_diff
            sim.model.site_size[self.food_site_ids] = np.maximum(0, size)

            # Respawn items whose health has been 0 for respawn_time steps.
            respawn = self.respawn_counters == self.respawn_time
            self.food_healths[respawn] = self.max_food_health
            sim.model.site_size[self.food_site_ids[respawn]] = self.max_food_size

            self.respawn_counters[self.food_healths[:, 0] == 0] += 1
            self.respawn_counters[self.food_healths[:, 0] != 0] = 0

            assert np.all(self.food_healths >= 0), \
                f"There is a food health below 0: {self.food_healths}"

            # calculate food reward
            if self.food_rew_type == 'selfish':
                food_rew = np.sum(eat, axis=1)
            elif self.food_rew_type == 'joint_mean':
                food_rew = np.sum(eat, axis=1)
                team_index = self.metadata['team_index']
                for team_index_number in np.unique(team_index):
                    on_team = team_index == team_index_number
                    food_rew[on_team] = np.mean(food_rew[on_team])
            else:
                raise ValueError(f"Food reward type {self.food_rew_type} unknown.")
        else:
            # No food this episode: nothing can be eaten. `eat` must still be bound,
            # otherwise the info assignment below raises UnboundLocalError.
            eat = np.zeros((self.n_agents, 0), dtype=np.float32)
            food_rew = 0.0

        info['agents_eat'] = eat
        rew += food_rew * self.curr_reward_scale
        return self.observation(obs), rew, done, info


class AlwaysEatWrapper(gym.ActionWrapper):
    '''
        Remove eat action and replace it with always eating.
        Args:
            agent_idx_allowed (ndarray): indices of agents allowed to eat.
    '''
    def __init__(self, env, agent_idx_allowed):
        super().__init__(env)
        self.action_space.spaces.pop('action_eat_food')
        self.agent_idx_allowed = agent_idx_allowed

    def action(self, action):
        '''
            Inject an (n_agents, curr_n_food) eat action: rows of allowed agents are
            all 1 (eat every item), all other rows are 0.
        '''
        action['action_eat_food'] = np.zeros(
            (self.metadata['n_agents'], self.metadata['curr_n_food']))
        action['action_eat_food'][self.agent_idx_allowed] = 1.
        return action