mujoco_worldgen/util/envs/env_viewer.py (68 lines of code) (raw):
import numpy as np
import time
from mujoco_py import const, MjViewer, ignore_mujoco_warnings
import glfw
from gym.spaces import Box
from gym.spaces import MultiDiscrete
class EnvViewer(MjViewer):
def __init__(self, env):
self.env = env
self.elapsed = [0]
self.env.reset()
self.seed = self.env.seed()
super().__init__(self.env.unwrapped.sim)
self.num_action = self.env.action_space.shape[0]
self.action_mod_index = 0
self.action = self.zero_action(self.env.action_space)
def zero_action(self, action_space):
if isinstance(action_space, Box):
return np.zeros(action_space.shape[0])
elif isinstance(action_space, MultiDiscrete):
return action_space.nvec // 2 # assume middle element is "no action" action
def env_reset(self):
start = time.time()
# get the seed before calling env.reset(), so we display the one
# that was used for the reset.
self.seed = self.env.seed()
self.env.reset()
self.elapsed.append(time.time() - start)
self.update_sim(self.env.unwrapped.sim)
def key_callback(self, window, key, scancode, action, mods):
# Trigger on keyup only:
if action != glfw.RELEASE:
return
if key == glfw.KEY_ESCAPE:
self.env.close()
# Increment experiment seed
elif key == glfw.KEY_N:
self.seed[0] += 1
self.env.seed(self.seed)
self.env_reset()
self.action = self.zero_action(self.env.action_space)
# Decrement experiment trial
elif key == glfw.KEY_P:
self.seed = [max(self.seed[0] - 1, 0)]
self.env.seed(self.seed)
self.env_reset()
self.action = self.zero_action(self.env.action_space)
if key == glfw.KEY_A:
if isinstance(self.env.action_space, Box):
self.action[self.action_mod_index] -= 0.05
elif key == glfw.KEY_Z:
if isinstance(self.env.action_space, Box):
self.action[self.action_mod_index] += 0.05
elif key == glfw.KEY_K:
self.action_mod_index = (self.action_mod_index + 1) % self.num_action
elif key == glfw.KEY_J:
self.action_mod_index = (self.action_mod_index - 1) % self.num_action
super().key_callback(window, key, scancode, action, mods)
def render(self):
super().render()
# Display applied external forces.
self.vopt.flags[8] = 1
def run(self, once=False):
while True:
with ignore_mujoco_warnings():
self.env.step(self.action)
self.add_overlay(const.GRID_TOPRIGHT, "Reset env; (current seed: {})".format(self.seed), "N - next / P - previous ")
self.add_overlay(const.GRID_TOPRIGHT, "Apply action", "A (-0.05) / Z (+0.05)")
self.add_overlay(const.GRID_TOPRIGHT, "on action index %d out %d" % (self.action_mod_index, self.num_action), "J / K")
self.add_overlay(const.GRID_BOTTOMRIGHT, "Reset took", "%.2f sec." % (sum(self.elapsed) / len(self.elapsed)))
self.add_overlay(const.GRID_BOTTOMRIGHT, "Action", str(self.action))
self.render()
if once:
return