mae_envs/wrappers/prep_phase.py (78 lines of code) (raw):

import gym
import numpy as np
from copy import deepcopy
from mae_envs.wrappers.util import update_obs_space


class PreparationPhase(gym.Wrapper):
    '''
        Rewards are switched off during preparation.

        While the wrapper is in the preparation phase every reward returned by
        ``step`` is replaced with zeros. The wrapper also adds:
            - a 'prep_obs' observation of shape (n_agents, 1) holding
              min(1, step_counter / prep_time), i.e. the fraction of the
              preparation phase that has elapsed (saturating at 1)
            - an 'in_prep_phase' boolean entry in the ``info`` dict

        Args:
            env: environment to wrap. ``env.unwrapped.horizon`` and
                ``metadata['n_agents']`` are read at construction time.
            prep_fraction (float): Fraction of total time that is preparation time
    '''
    def __init__(self, env, prep_fraction=.2):
        super().__init__(env)
        self.prep_fraction = prep_fraction
        # Preparation length in steps (may be fractional)
        self.prep_time = self.prep_fraction * self.unwrapped.horizon
        self.n_agents = self.metadata['n_agents']
        self.step_counter = 0
        self.observation_space = update_obs_space(self, {'prep_obs': [self.n_agents, 1]})

    def reset(self):
        '''Restart the step counter and the preparation phase, then reset the wrapped env.'''
        self.step_counter = 0
        # NOTE: the flag always starts as True, so the first step of an episode is
        # treated as preparation even when prep_time is 0.
        self.in_prep_phase = True
        return self.observation(self.env.reset())

    def reward(self, reward):
        '''Return zeros shaped like ``reward`` during preparation, else ``reward`` unchanged.'''
        if self.in_prep_phase:
            reward = np.zeros_like(reward)
        return reward

    def observation(self, obs):
        '''Add the 'prep_obs' entry to ``obs`` (modified in place) and return it.'''
        # The small epsilon avoids a division by zero when prep_time is 0
        obs['prep_obs'] = (np.ones((self.n_agents, 1)) *
                           np.minimum(1.0, self.step_counter / (self.prep_time + 1e-5)))
        return obs

    def step(self, action):
        '''Step the wrapped env, masking the reward and reporting the phase in ``info``.'''
        obs, rew, done, info = self.env.step(action)
        # The reward is masked with the phase flag as it was *before* this step
        rew = self.reward(rew)
        self.step_counter += 1
        self.in_prep_phase = self.step_counter < self.prep_time
        info['in_prep_phase'] = self.in_prep_phase

        return self.observation(obs), rew, done, info


class NoActionsInPrepPhase(gym.Wrapper):
    '''
        Agents have all actions turned off during preparation phase.
        For MultiDiscrete and Discrete, assumes zero action is the rounded down middle action.
        For any other space type the zero action is 0.0.

        The phase flag is read from ``info['in_prep_phase']`` returned by the wrapped
        env, so a wrapper that writes that key (e.g. PreparationPhase) must sit below
        this one.

        Args:
            env: environment to wrap. Its action space is accessed as a mapping of
                keys to spaces, each of which exposes a ``spaces`` sequence.
            agent_idxs: indices of the agents whose actions are turned off
    '''
    def __init__(self, env, agent_idxs):
        super().__init__(env)
        self.agent_idxs = np.array(agent_idxs)

    def reset(self):
        '''Reset the wrapped env and restart the preparation phase.'''
        obs = self.env.reset()
        self.in_prep_phase = True
        return obs

    def step(self, action):
        '''Step the wrapped env with the (possibly zeroed) action and update the phase flag.'''
        obs, rew, done, info = self.env.step(self.action(action))
        self.in_prep_phase = info['in_prep_phase']
        return obs, rew, done, info

    def action(self, action):
        '''
            Return a deep copy of ``action`` in which, during the preparation phase,
            the entries of the selected agents are replaced by the zero action.
            The caller's ``action`` is never modified.
        '''
        ac = deepcopy(action)

        if self.in_prep_phase:
            for k, space in self.action_space.spaces.items():
                # The first sub-space is used to decide the zero action for this key
                _space = space.spaces[0]
                if isinstance(_space, gym.spaces.MultiDiscrete):
                    zero_ac = (_space.nvec - 1) // 2
                elif isinstance(_space, gym.spaces.Discrete):
                    zero_ac = (_space.n - 1) // 2
                else:
                    zero_ac = 0.0
                ac[k][self.agent_idxs] = zero_ac

        return ac


class MaskPrepPhaseAction(gym.Wrapper):
    '''
        Masks a (binary) action during preparation phase.

        The phase flag is read from ``info['in_prep_phase']`` returned by the wrapped
        env, so a wrapper that writes that key (e.g. PreparationPhase) must sit below
        this one.

        Args:
            env: environment to wrap
            action_key: key of the action entry to mask
    '''
    def __init__(self, env, action_key):
        super().__init__(env)
        self.action_key = action_key

    def reset(self):
        '''Reset the wrapped env and restart the preparation phase.'''
        obs = self.env.reset()
        self.in_prep_phase = True
        return obs

    def step(self, action):
        '''
            Step the wrapped env with ``action[action_key]`` forced to False during
            the preparation phase (and cast to bool otherwise).
        '''
        # Work on a copy so the caller's action is not modified in place
        # (same approach as NoActionsInPrepPhase.action).
        action = deepcopy(action)
        # (1 - in_prep_phase) is 0 during preparation and 1 afterwards
        action[self.action_key] = (action[self.action_key] * (1 - self.in_prep_phase)).astype(bool)

        obs, rew, done, info = self.env.step(action)
        self.in_prep_phase = info['in_prep_phase']

        return obs, rew, done, info