# step() — excerpt from mae_envs/wrappers/food.py


    def step(self, action):
        """Step the wrapped env, then resolve food eating, health, and reward.

        Pops the 'action_eat_food' entry from ``action`` (a per-agent,
        per-food eat intent), steps the underlying env with the remainder,
        and — when food is present — computes how much each agent eats,
        decrements food health/site size, handles respawning, and adds a
        food reward (scaled by ``self.curr_reward_scale``) to ``rew``.

        Args:
            action: dict of actions; must contain 'action_eat_food'
                (removed in place before delegating to ``self.env.step``).

        Returns:
            (obs, rew, done, info) tuple; ``info['agents_eat']`` holds the
            per-agent x per-food amounts eaten this step (empty second axis
            when there is no food).
        """
        action_eat_food = action.pop('action_eat_food')
        obs, rew, done, info = self.env.step(action)

        if self.curr_n_food > 0:
            # Agents may only eat food within eat_thresh that still has health.
            dist_to_food = np.linalg.norm(obs['agent_pos'][:, None] - obs['food_pos'][None], axis=-1)
            eat = np.logical_and(dist_to_food < self.eat_thresh, self.food_healths.T > 0)
            eat = np.logical_and(eat, action_eat_food).astype(np.float32)
            if self.split_eat_between_agents:
                # Each food item is shared equally among all agents eating it.
                eat_per_food = np.sum(eat, 0)
                eat[:, eat_per_food > 0] /= eat_per_food[eat_per_food > 0]
            eat_per_food = np.sum(eat, 0)

            # Make sure that all agents can't have the last bite of food.
            # At that point, food is split evenly among the eaters.
            over_eat = self.food_healths[:, 0] < eat_per_food
            eat[:, over_eat] *= (self.food_healths[over_eat, 0] / eat_per_food[over_eat])
            eat_per_food = np.sum(eat, 0)
            self.eat_per_food = eat_per_food[:, None]

            # Update food healths and shrink site sizes proportionally.
            self.food_healths -= eat_per_food[:, None]
            health_diff = eat_per_food[:, None]
            size_diff = health_diff * (self.max_food_size / self.max_food_health)
            size = self.unwrapped.sim.model.site_size[self.food_site_ids] - size_diff
            size = np.maximum(0, size)
            self.unwrapped.sim.model.site_size[self.food_site_ids] = size

            # Respawn food that has been depleted for respawn_time steps.
            self.food_healths[self.respawn_counters == self.respawn_time] = self.max_food_health
            self.unwrapped.sim.model.site_size[self.food_site_ids[self.respawn_counters == self.respawn_time]] = self.max_food_size
            # NOTE(review): exact float == 0 comparison — the over-eat cap
            # above is meant to make depleted health land exactly on 0, but
            # float rounding could leave a tiny residue; confirm upstream.
            self.respawn_counters[self.food_healths[:, 0] == 0] += 1
            self.respawn_counters[self.food_healths[:, 0] != 0] = 0

            assert np.all(self.food_healths >= 0), \
                f"There is a food health below 0: {self.food_healths}"

            # calculate food reward
            if self.food_rew_type == 'selfish':
                food_rew = np.sum(eat, axis=1)
            elif self.food_rew_type == 'joint_mean':
                # Agents on the same team share the mean of their food reward.
                food_rew = np.sum(eat, axis=1)
                team_index = self.metadata['team_index']
                for team_index_number in np.unique(team_index):
                    food_rew[team_index == team_index_number] = np.mean(food_rew[team_index == team_index_number])
            else:
                raise ValueError(f"Food reward type {self.food_rew_type} unknown.")
        else:
            food_rew = 0.0
            # Bug fix: `eat` was previously unbound on this path, so the
            # unconditional info['agents_eat'] assignment below raised a
            # NameError whenever curr_n_food == 0. Report an empty
            # (n_agents, 0) matrix instead.
            eat = np.zeros((obs['agent_pos'].shape[0], 0), dtype=np.float32)

        info['agents_eat'] = eat
        rew += food_rew * self.curr_reward_scale
        return self.observation(obs), rew, done, info