in mae_envs/wrappers/food.py [0:0]
def step(self, action):
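    """Step the wrapped env, let agents eat nearby food, shrink and respawn
    food sites, and add the resulting food reward to the env reward."""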
    action_eat_food = action.pop('action_eat_food')
    obs, rew, done, info = self.env.step(action)

    if self.curr_n_food > 0:
        # Eat food that is close enough
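        # Broadcasting agent positions against food positions yields an
        # (n_agents, n_food) matrix of pairwise distances.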
        dist_to_food = np.linalg.norm(obs['agent_pos'][:, None] - obs['food_pos'][None], axis=-1)
        eat = np.logical_and(dist_to_food < self.eat_thresh, self.food_healths.T > 0)
        eat = np.logical_and(eat, action_eat_food).astype(np.float32)

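        # eat[i, j] is agent i's bite of food j this step. Optionally split
        # each food item evenly among all agents currently eating it.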
        if self.split_eat_between_agents:
            eat_per_food = np.sum(eat, 0)
            eat[:, eat_per_food > 0] /= eat_per_food[eat_per_food > 0]
        eat_per_food = np.sum(eat, 0)

        # Make sure the agents can't collectively eat more than the last
        # bite of a food item; at that point the last bite is split evenly.
        over_eat = self.food_healths[:, 0] < eat_per_food
        eat[:, over_eat] *= (self.food_healths[over_eat, 0] / eat_per_food[over_eat])
        eat_per_food = np.sum(eat, 0)
        self.eat_per_food = eat_per_food[:, None]

        # Update food healths and sizes
        self.food_healths -= eat_per_food[:, None]
        health_diff = eat_per_food[:, None]
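        # Shrink each food site in proportion to the health it lost; sizes
        # are clamped at zero so a site's size never goes negative.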
        size_diff = health_diff * (self.max_food_size / self.max_food_health)
        size = self.unwrapped.sim.model.site_size[self.food_site_ids] - size_diff
        size = np.maximum(0, size)
        self.unwrapped.sim.model.site_size[self.food_site_ids] = size

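        # Respawn food that has stayed depleted for `respawn_time` steps:
        # restore full health and size. Counters accumulate only while a
        # food item sits at zero health.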
        respawn_now = self.respawn_counters == self.respawn_time
        self.food_healths[respawn_now] = self.max_food_health
        self.unwrapped.sim.model.site_size[self.food_site_ids[respawn_now]] = self.max_food_size
        self.respawn_counters[self.food_healths[:, 0] == 0] += 1
        self.respawn_counters[self.food_healths[:, 0] != 0] = 0

        assert np.all(self.food_healths >= 0), \
            f"There is a food health below 0: {self.food_healths}"

        # Calculate food reward
        if self.food_rew_type == 'selfish':
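            # 'selfish': each agent is rewarded only for what it ate itself.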
            food_rew = np.sum(eat, axis=1)
        elif self.food_rew_type == 'joint_mean':
            food_rew = np.sum(eat, axis=1)
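            # 'joint_mean': replace each agent's reward with its team's mean,
            # so teammates share food reward equally.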
            team_index = self.metadata['team_index']
            for team_index_number in np.unique(team_index):
                food_rew[team_index == team_index_number] = \
                    np.mean(food_rew[team_index == team_index_number])
        else:
            raise ValueError(f"Food reward type {self.food_rew_type} unknown.")
    else:
        food_rew = 0.0
        # No food in the scene: use an empty eat matrix so the
        # info['agents_eat'] entry below does not raise a NameError.
        eat = np.zeros((obs['agent_pos'].shape[0], 0), dtype=np.float32)

    info['agents_eat'] = eat
    rew += food_rew * self.curr_reward_scale
    return self.observation(obs), rew, done, info
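
A minimal, self-contained sketch of the eat/split/cap arithmetic above, runnable on toy arrays. All values here (positions, threshold, healths) are invented for illustration; only the numpy operations mirror the wrapper.

import numpy as np

agent_pos = np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]])  # 2 agents
food_pos = np.array([[0.1, 0.0, 0.0], [5.0, 0.0, 0.0]])   # 2 food items
food_healths = np.array([[0.5], [1.0]])                   # (n_food, 1)
eat_thresh = 0.5
action_eat_food = np.ones((2, 2))                         # everyone tries to eat

dist_to_food = np.linalg.norm(agent_pos[:, None] - food_pos[None], axis=-1)
eat = np.logical_and(dist_to_food < eat_thresh, food_healths.T > 0)
eat = np.logical_and(eat, action_eat_food).astype(np.float32)
# Both agents reach food 0, neither reaches food 1: [[1, 0], [1, 0]]

eat_per_food = eat.sum(0)                # split evenly between co-eaters
eat[:, eat_per_food > 0] /= eat_per_food[eat_per_food > 0]
# [[0.5, 0], [0.5, 0]]

eat_per_food = eat.sum(0)                # cap at remaining health (0.5)
over_eat = food_healths[:, 0] < eat_per_food
eat[:, over_eat] *= food_healths[over_eat, 0] / eat_per_food[over_eat]
# [[0.25, 0], [0.25, 0]]; food 0 ends the step at exactly zero health.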