# mae_envs/viewer/policy_viewer.py
#!/usr/bin/env python
import time
import glfw
import numpy as np
from operator import itemgetter
from mujoco_py import const, MjViewer
from mujoco_worldgen.util.types import store_args
from ma_policy.util import listdict2dictnp
def splitobs(obs, keepdims=True):
'''
Split obs into list of single agent obs.
Args:
obs: dictionary of numpy arrays where first dim in each array is agent dim
'''
n_agents = obs[list(obs.keys())[0]].shape[0]
return [{k: v[[i]] if keepdims else v[i] for k, v in obs.items()} for i in range(n_agents)]
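
# Illustrative example (not part of the original module):
#   obs = {'pos': np.zeros((3, 2))}       # 3 agents
#   splitobs(obs)                  -> three dicts whose 'pos' entries have shape (1, 2)
#   splitobs(obs, keepdims=False)  -> three dicts whose 'pos' entries have shape (2,)
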
class PolicyViewer(MjViewer):
'''
    PolicyViewer runs one or more policies in an environment and optionally displays it.
        env - environment to run the policies in
        policies - list of policy objects to run; agents are divided evenly among them
        display_window - if true, show the graphical viewer
        seed - environment seed to view
        duration - time in seconds to run the policy, run forever if duration=None
'''
@store_args
def __init__(self, env, policies, display_window=True, seed=None, duration=None):
if seed is None:
self.seed = env.seed()[0]
else:
self.seed = seed
env.seed(seed)
self.total_rew = 0.0
self.ob = env.reset()
for policy in self.policies:
policy.reset()
assert env.metadata['n_actors'] % len(policies) == 0
if hasattr(env, "reset_goal"):
self.goal = env.reset_goal()
super().__init__(self.env.unwrapped.sim)
        # TODO: remove the circular dependency on the viewer object; it looks fishy.
self.env.unwrapped.viewer = self
        # self.render is a bound method and therefore always truthy, so
        # display_window is the only meaningful part of this check.
        if self.display_window:
self.env.render()
def key_callback(self, window, key, scancode, action, mods):
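        # Every key is forwarded to the base MjViewer handler; on release,
        # N additionally advances to the next seed and P steps back one seed.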
super().key_callback(window, key, scancode, action, mods)
# Trigger on keyup only:
if action != glfw.RELEASE:
return
# Increment experiment seed
if key == glfw.KEY_N:
self.reset_increment()
        # Decrement experiment seed
elif key == glfw.KEY_P:
print("Pressed P")
self.seed = max(self.seed - 1, 0)
self.env.seed(self.seed)
self.ob = self.env.reset()
for policy in self.policies:
policy.reset()
if hasattr(self.env, "reset_goal"):
self.goal = self.env.reset_goal()
self.update_sim(self.env.unwrapped.sim)
def run(self):
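        # Query the policies for actions and step the environment until
        # `duration` seconds have elapsed (or indefinitely if duration is None).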
if self.duration is not None:
self.end_time = time.time() + self.duration
self.total_rew_avg = 0.0
self.n_episodes = 0
while self.duration is None or time.time() < self.end_time:
if len(self.policies) == 1:
action, _ = self.policies[0].act(self.ob)
else:
                # Split the joint observation into per-agent observations and
                # divide the agents evenly across the policies.
                self.ob = splitobs(self.ob, keepdims=False)
                ob_policy_idx = np.split(np.arange(len(self.ob)), len(self.policies))
                actions = []
                for i, policy in enumerate(self.policies):
                    # Re-batch the observations assigned to this policy and query it.
                    inp = itemgetter(*ob_policy_idx[i])(self.ob)
                    inp = listdict2dictnp([inp] if ob_policy_idx[i].shape[0] == 1 else inp)
                    ac, info = policy.act(inp)
                    actions.append(ac)
                # Merge the per-policy actions back into a single joint action dict.
                action = listdict2dictnp(actions, keepdims=True)
self.ob, rew, done, env_info = self.env.step(action)
self.total_rew += rew
if done or env_info.get('discard_episode', False):
self.reset_increment()
if self.display_window:
self.add_overlay(const.GRID_TOPRIGHT, "Reset env; (current seed: {})".format(self.seed), "N - next / P - previous ")
self.add_overlay(const.GRID_TOPRIGHT, "Reward", str(self.total_rew))
if hasattr(self.env.unwrapped, "viewer_stats"):
for k, v in self.env.unwrapped.viewer_stats.items():
self.add_overlay(const.GRID_TOPRIGHT, k, str(v))
self.env.render()
def reset_increment(self):
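        # Fold the finished episode's reward into a running average, advance the
        # seed, and reset the environment and policies for the next episode.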
self.total_rew_avg = (self.n_episodes * self.total_rew_avg + self.total_rew) / (self.n_episodes + 1)
self.n_episodes += 1
print(f"Reward: {self.total_rew} (rolling average: {self.total_rew_avg})")
self.total_rew = 0.0
self.seed += 1
self.env.seed(self.seed)
self.ob = self.env.reset()
for policy in self.policies:
policy.reset()
if hasattr(self.env, "reset_goal"):
self.goal = self.env.reset_goal()
self.update_sim(self.env.unwrapped.sim)
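
# Usage sketch (illustrative only; `make_env` and `load_policy` are assumed,
# user-provided helpers and are not part of this module):
#
#   env = make_env()
#   policies = [load_policy(path) for path in policy_paths]
#   viewer = PolicyViewer(env, policies, display_window=True, duration=60)
#   viewer.run()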