in mae_envs/viewer/policy_viewer.py [0:0]
def run(self):
    if self.duration is not None:
        self.end_time = time.time() + self.duration
    self.total_rew_avg = 0.0
    self.n_episodes = 0
    # self.ob and self.total_rew are assumed to be initialized elsewhere
    # (e.g. in __init__ / reset_increment) before this loop starts.
    while self.duration is None or time.time() < self.end_time:
        if len(self.policies) == 1:
            # A single policy controls all agents at once.
            action, _ = self.policies[0].act(self.ob)
        else:
            # Split the batched observation into per-agent dicts, assign
            # agents evenly across the policies, and query each policy on
            # its own slice of agents.
            self.ob = splitobs(self.ob, keepdims=False)
            ob_policy_idx = np.split(np.arange(len(self.ob)), len(self.policies))
            actions = []
            for i, policy in enumerate(self.policies):
                inp = itemgetter(*ob_policy_idx[i])(self.ob)
                # itemgetter returns a bare dict for a single index, so
                # re-wrap it in a list before merging.
                inp = listdict2dictnp([inp] if ob_policy_idx[i].shape[0] == 1 else inp)
                ac, info = policy.act(inp)
                actions.append(ac)
            # Merge the per-policy action dicts back into one batched action dict.
            action = listdict2dictnp(actions, keepdims=True)

        self.ob, rew, done, env_info = self.env.step(action)
        self.total_rew += rew

        if done or env_info.get('discard_episode', False):
            self.reset_increment()

        if self.display_window:
            self.add_overlay(const.GRID_TOPRIGHT, "Reset env; (current seed: {})".format(self.seed), "N - next / P - previous ")
            self.add_overlay(const.GRID_TOPRIGHT, "Reward", str(self.total_rew))
            if hasattr(self.env.unwrapped, "viewer_stats"):
                for k, v in self.env.unwrapped.viewer_stats.items():
                    self.add_overlay(const.GRID_TOPRIGHT, k, str(v))

            self.env.render()
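
For reference, a minimal sketch of the two observation helpers used above, inferred from their call sites here (an assumption, not necessarily the repo's exact implementation): observations and actions are dicts mapping key -> np.ndarray whose first axis is the agent axis.

import numpy as np

def splitobs(obs, keepdims=True):
    # Split a batched observation dict into a list of per-agent dicts.
    # With keepdims=False (as used in run()), the agent axis is dropped.
    n_agents = next(iter(obs.values())).shape[0]
    return [{k: (v[[i]] if keepdims else v[i]) for k, v in obs.items()}
            for i in range(n_agents)]

def listdict2dictnp(lst, keepdims=False):
    # Merge a list of dicts of arrays into one dict of batched arrays.
    # keepdims=True concatenates along the existing agent axis; otherwise
    # a new leading axis of length len(lst) is created.
    if keepdims:
        return {k: np.concatenate([d[k] for d in lst]) for k in lst[0]}
    return {k: np.array([d[k] for d in lst]) for k in lst[0]}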
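
And a hypothetical driver showing how run() is reached, assuming a PolicyViewer(env, policies, ...) constructor that resets the env into self.ob; env_fn and load_policy are stand-in loaders, not names from this file.

env = env_fn()                                      # multi-agent gym-style env
policies = [load_policy(path, env) for path in policy_paths]
viewer = PolicyViewer(env, policies, duration=60)   # view rollouts for ~60 s
viewer.run()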