mae_envs/modules/food.py (53 lines of code) (raw):
import numpy as np
from mujoco_worldgen.util.types import store_args
from mae_envs.modules import EnvModule, rejection_placement
class Food(EnvModule):
    '''
    Add food sites to the environment.

    Args:
        n_food (int or (int, int)): number of food items. If a tuple/list of ints, the number
            of food items is drawn every episode uniformly from range(n_food[0], n_food[1] + 1).
        food_size (float): (visual) size of food items
        placement_fn (fn or list of fns): See mae_envs.modules.util:rejection_placement for spec.
            If a list (or tuple) of functions, it is assumed there is one function given per
            food site.
    '''
    @store_args
    def __init__(self, n_food, food_size=0.1, placement_fn=None):
        # @store_args stores the constructor arguments as attributes (food_size and
        # placement_fn are read later without being assigned here). A scalar n_food is
        # normalised into a degenerate [min, max] range so the rest of the class can always
        # index n_food[0] / n_food[1].
        if not isinstance(n_food, (tuple, list, np.ndarray)):
            self.n_food = [n_food, n_food]

    def _placement_fn_for(self, i):
        '''Return the placement function for food site i (per-site if a sequence was given).'''
        if isinstance(self.placement_fn, (list, tuple)):
            return self.placement_fn[i]
        return self.placement_fn

    def build_world_step(self, env, floor, floor_size):
        '''
        Sample this episode's food count and mark one site per food item on the floor.

        Writes 'food_size', 'max_n_food' and 'curr_n_food' into env.metadata, plus
        'food{i}_initpos' for every food item that was placed via rejection_placement.

        Returns:
            bool: False if rejection_placement failed to find a position for any food item,
                True otherwise.
        '''
        env.metadata['food_size'] = self.food_size
        # The upper bound passed to randint is exclusive (assuming env._random_state is a
        # numpy RandomState), hence the + 1 so that n_food[1] itself can be drawn.
        self.curr_n_food = env._random_state.randint(self.n_food[0], self.n_food[1] + 1)
        env.metadata['max_n_food'] = self.n_food[1]
        env.metadata['curr_n_food'] = self.curr_n_food
        successful_placement = True

        # Drop spawn positions left over from previous builds. Clear every index this module
        # can ever write (up to the maximum food count), not only this episode's count, so
        # that stale entries from an earlier episode with more food do not linger.
        for i in range(self.n_food[1]):
            env.metadata.pop(f"food{i}_initpos", None)

        food_dims = (self.food_size, self.food_size, self.food_size)
        food_rgba = (0., 1., 0., 1.)

        # Add food sites
        for i in range(self.curr_n_food):
            if self.placement_fn is None:
                # No placement function: mark the site without an explicit relative_xyz.
                floor.mark(f"food{i}", rgba=food_rgba, size=food_dims)
                continue

            pos, pos_grid = rejection_placement(env, self._placement_fn_for(i), floor_size,
                                                np.array([self.food_size, self.food_size]))
            if pos is None:
                # Keep going so the remaining items are still attempted; the caller sees the
                # failure through the return value.
                successful_placement = False
                continue

            # Lift the site by half its size along z.
            floor.mark(f"food{i}", relative_xyz=np.append(pos, [self.food_size / 2]),
                       size=food_dims, rgba=food_rgba)
            # Store spawn position in metadata. This allows sampling subsequent food items
            # close to previous food items.
            env.metadata[f"food{i}_initpos"] = pos_grid

        return successful_placement

    def modify_sim_step(self, env, sim):
        '''Cache the simulator site ids of this episode's food sites.'''
        # dtype=int keeps the array usable as an index array even when it is empty.
        self.food_site_ids = np.array([sim.model.site_name2id(f'food{i}')
                                       for i in range(self.curr_n_food)], dtype=int)

    def observation_step(self, env, sim):
        '''
        Return {'food_pos': positions of the food sites read from sim.data.site_xpos}.

        When there is no food this episode, an empty array of shape (0, 3) is returned.
        '''
        if self.curr_n_food > 0:
            obs = {'food_pos': sim.data.site_xpos[self.food_site_ids]}
        else:
            obs = {'food_pos': np.zeros((0, 3))}
        return obs