mae_envs/viewer/env_viewer.py
import numpy as np
import time
from mujoco_py import const, MjViewer
import glfw
from gym.spaces import Box, MultiDiscrete, Discrete
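
# Interactive MuJoCo viewer for multi-agent environments. Key bindings
# (handled by EnvViewer.key_callback below):
#   N / P - next / previous environment seed (resets the environment)
#   A / Z - decrease / increase the currently selected action dimension
#   J / K - previous / next action index within the current action type
#   Y / U - next / previous agent
#   G / B - next / previous action type
#   ESC   - close the environment
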
class EnvViewer(MjViewer):

    def __init__(self, env):
        self.env = env
        self.elapsed = [0]
        self.seed = self.env.seed()
        super().__init__(self.env.unwrapped.sim)
        self.n_agents = self.env.metadata['n_actors']
        self.action_types = list(self.env.action_space.spaces.keys())
        self.num_action_types = len(self.env.action_space.spaces)
        self.num_action = self.num_actions(self.env.action_space)
        self.agent_mod_index = 0
        self.action_mod_index = 0
        self.action_type_mod_index = 0
        self.action = self.zero_action(self.env.action_space)
        self.env_reset()
    def num_actions(self, ac_space):
        # Number of adjustable action dimensions for each action type.
        n_actions = []
        for k, tuple_space in ac_space.spaces.items():
            s = tuple_space.spaces[0]
            if isinstance(s, Box):
                n_actions.append(s.shape[0])
            elif isinstance(s, Discrete):
                n_actions.append(1)
            elif isinstance(s, MultiDiscrete):
                n_actions.append(s.nvec.shape[0])
            else:
                raise NotImplementedError(f"Action space of type {type(s)} is not supported")
        return n_actions
    def zero_action(self, ac_space):
        ac = {}
        for k, space in ac_space.spaces.items():
            if isinstance(space.spaces[0], Box):
                ac[k] = np.zeros_like(space.sample())
            elif isinstance(space.spaces[0], Discrete):
                # assume the middle element is the "no action" action
                ac[k] = np.ones_like(space.sample()) * (space.spaces[0].n // 2)
            elif isinstance(space.spaces[0], MultiDiscrete):
                # assume the middle element of each dimension is the "no action" action
                ac[k] = np.ones_like(space.sample(), dtype=int) * (space.spaces[0].nvec // 2)
            else:
                raise NotImplementedError(f"Action space of type {type(space.spaces[0])} is not supported")
        return ac
    def env_reset(self):
        start = time.time()
        # get the seed before calling env.reset(), so we display the one
        # that was used for the reset.
        self.seed = self.env.seed()
        self.env.reset()
        self.elapsed.append(time.time() - start)
        self.update_sim(self.env.unwrapped.sim)
    def key_callback(self, window, key, scancode, action, mods):
        # Trigger on keyup only:
        if action != glfw.RELEASE:
            return
        if key == glfw.KEY_ESCAPE:
            self.env.close()
        # Increment experiment seed
        elif key == glfw.KEY_N:
            self.seed[0] += 1
            self.env.seed(self.seed)
            self.env_reset()
            self.action = self.zero_action(self.env.action_space)
        # Decrement experiment seed
        elif key == glfw.KEY_P:
            self.seed = [max(self.seed[0] - 1, 0)]
            self.env.seed(self.seed)
            self.env_reset()
            self.action = self.zero_action(self.env.action_space)
        current_action_space = self.env.action_space.spaces[self.action_types[self.action_type_mod_index]].spaces[0]
        # Decrease the currently selected action dimension
        if key == glfw.KEY_A:
            if isinstance(current_action_space, Box):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] -= 0.05
            elif isinstance(current_action_space, Discrete):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index] = \
                    (self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index] - 1) % current_action_space.n
            elif isinstance(current_action_space, MultiDiscrete):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] = \
                    (self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] - 1) \
                    % current_action_space.nvec[self.action_mod_index]
        # Increase the currently selected action dimension
        elif key == glfw.KEY_Z:
            if isinstance(current_action_space, Box):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] += 0.05
            elif isinstance(current_action_space, Discrete):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index] = \
                    (self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index] + 1) % current_action_space.n
            elif isinstance(current_action_space, MultiDiscrete):
                self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] = \
                    (self.action[self.action_types[self.action_type_mod_index]][self.agent_mod_index][self.action_mod_index] + 1) \
                    % current_action_space.nvec[self.action_mod_index]
        elif key == glfw.KEY_K:
            self.action_mod_index = (self.action_mod_index + 1) % self.num_action[self.action_type_mod_index]
        elif key == glfw.KEY_J:
            self.action_mod_index = (self.action_mod_index - 1) % self.num_action[self.action_type_mod_index]
        elif key == glfw.KEY_Y:
            self.agent_mod_index = (self.agent_mod_index + 1) % self.n_agents
        elif key == glfw.KEY_U:
            self.agent_mod_index = (self.agent_mod_index - 1) % self.n_agents
        elif key == glfw.KEY_G:
            self.action_type_mod_index = (self.action_type_mod_index + 1) % self.num_action_types
            self.action_mod_index = 0
        elif key == glfw.KEY_B:
            self.action_type_mod_index = (self.action_type_mod_index - 1) % self.num_action_types
            self.action_mod_index = 0
        super().key_callback(window, key, scancode, action, mods)
    def run(self, once=False):
        while True:
            _, _, _, env_info = self.env.step(self.action)
            if env_info.get('discard_episode', False):
                self.env.reset()
            self.add_overlay(const.GRID_TOPRIGHT, "Reset env; (current seed: {})".format(self.seed), "N - next / P - previous")
            self.add_overlay(const.GRID_TOPRIGHT, "Apply action", "A (-0.05) / Z (+0.05)")
            self.add_overlay(const.GRID_TOPRIGHT, "on agent index %d out of %d" % (self.agent_mod_index, self.n_agents), "Y / U")
            self.add_overlay(const.GRID_TOPRIGHT, f"on action type {self.action_types[self.action_type_mod_index]}", "G / B")
            self.add_overlay(const.GRID_TOPRIGHT, "on action index %d out of %d" % (self.action_mod_index, self.num_action[self.action_type_mod_index]), "J / K")
            self.add_overlay(const.GRID_BOTTOMRIGHT, "Reset took", "%.2f sec." % (sum(self.elapsed) / len(self.elapsed)))
            self.add_overlay(const.GRID_BOTTOMRIGHT, "Action", str(self.action))
            self.render()
            if once:
                return
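
# Minimal usage sketch (not part of the original file; `make_env` below is a
# hypothetical stand-in for whatever factory builds the environment). The only
# assumptions are the ones this class already relies on: a Dict action space of
# per-agent Tuple spaces, metadata['n_actors'], seed(), and an unwrapped
# mujoco_py sim.
#
#     env = make_env()          # hypothetical environment factory
#     viewer = EnvViewer(env)   # wraps the env's MuJoCo sim and resets it
#     viewer.run()              # steps the env with the interactively edited action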