ppo_ewma/vec_monitor2.py (79 lines of code) (raw):
import time
from collections import deque, namedtuple
import numpy as np
import gym3
# Record of one finished episode:
#   ret  - summed reward over the episode
#   len  - number of steps in the episode
#   time - seconds elapsed since the monitor was created (rounded to 6 places)
#   info - the env's info entry captured when the episode boundary was seen
Episode = namedtuple("Episode", "ret len time info")
class PostActProcessing(gym3.Wrapper):
    """
    Wrapper that runs a deferred ``process()`` hook once per action.

    ``act()`` only marks the hook as pending; the hook itself runs at the
    start of the next ``act()`` or ``observe()`` call. As a result, the hook
    is not run for a final action that is never followed by one of those calls.

    Subclasses must override ``process()``.
    """

    def __init__(self, env):
        super().__init__(env)
        # True from the moment an action is forwarded until process() has
        # been run for it.
        self.need_process = False

    def process_if_needed(self):
        """Run the pending ``process()`` call, if there is one, then clear the flag."""
        if not self.need_process:
            return
        self.process()
        self.need_process = False

    def act(self, ac):
        """Flush pending processing, forward ``ac`` to the wrapped env, mark processing pending."""
        self.process_if_needed()
        self.env.act(ac)
        self.need_process = True

    def observe(self):
        """Flush pending processing, then return the wrapped env's ``observe()`` result."""
        self.process_if_needed()
        return self.env.observe()

    def process(self):
        """Hook invoked once per action; must be implemented by subclasses."""
        raise NotImplementedError
class VecMonitor2(PostActProcessing):
    """
    Accumulate per-env episode return and length, and record an ``Episode``
    each time the wrapped env reports ``first`` for an env after an action.
    """

    def __init__(
        self,
        venv: "(gym3.Env)",
        keep_buf: "(int) how many returns/lengths/infos to keep" = 0,
        keep_sep_eps: "keep separate buffer per env" = False,
        keep_non_rolling: "keep separate buffer that must be explicitly cleared" = False,
    ):
        """
        Set up the running accumulators and whichever episode buffers are enabled.

        venv: the env to wrap.
        keep_buf: if non-zero, ``ep_buf`` is a deque holding at most this many
            of the most recent episodes; otherwise ``ep_buf`` is None.
        keep_sep_eps: if true, ``per_env_buf`` holds one episode list per env;
            otherwise it is None.
        keep_non_rolling: if true, ``non_rolling_buf`` is an unbounded deque
            that only shrinks when cleared; otherwise it is None.
        """
        super().__init__(venv)

        # Running totals for the episode currently in progress in each env.
        self.eprets = np.zeros(self.num, "f")
        self.eplens = np.zeros(self.num, "i")

        self.epcount = 0
        self.tstart = time.time()

        # A disabled buffer is represented by None.
        self.ep_buf = deque([], maxlen=keep_buf) if keep_buf else None
        self.per_env_buf = [[] for _ in range(self.num)] if keep_sep_eps else None
        self.non_rolling_buf = deque([]) if keep_non_rolling else None

    def process(self):
        """Fold the latest rewards into the accumulators and record finished episodes."""
        rewards, _, firsts = self.env.observe()
        infos = self.env.get_info()

        self.eprets += rewards
        self.eplens += 1

        for idx in range(self.num):
            if not firsts[idx]:
                continue

            elapsed = round(time.time() - self.tstart, 6)
            finished = Episode(self.eprets[idx], self.eplens[idx], elapsed, infos[idx])

            if self.ep_buf is not None:
                self.ep_buf.append(finished)
            if self.per_env_buf is not None:
                self.per_env_buf[idx].append(finished)
            if self.non_rolling_buf is not None:
                self.non_rolling_buf.append(finished)

            self.epcount += 1
            # Restart the accumulators for the episode that has just begun.
            self.eprets[idx] = 0
            self.eplens[idx] = 0

    def clear_episode_bufs(self):
        """Empty every enabled episode buffer."""
        if self.ep_buf is not None:
            self.ep_buf.clear()
        self.clear_per_env_episode_buf()
        self.clear_non_rolling_episode_buf()

    def clear_per_env_episode_buf(self):
        """Empty each env's episode list, if per-env buffering is enabled."""
        if self.per_env_buf is None:
            return
        for env_eps in self.per_env_buf:
            env_eps.clear()

    def clear_non_rolling_episode_buf(self):
        """Empty the non-rolling buffer, if it is enabled."""
        if self.non_rolling_buf is not None:
            self.non_rolling_buf.clear()