ppo_ewma/roller.py

from collections import defaultdict

import numpy as np
import torch as th

from . import torch_util as tu
from .tree_util import tree_map
from .vec_monitor2 import VecMonitor2


class Roller:
    def __init__(
        self,
        *,
        venv: "(VecEnv)",
        act_fn: "ob, state_in, first -> action, state_out, dict",
        initial_state: "RNN state",
        keep_buf: "number of episode stats to keep in rolling buffer" = 100,
        keep_sep_eps: "keep buffer of per-env episodes in VecMonitor2" = False,
        keep_non_rolling: "also keep a non-rolling buffer of episode stats" = False,
        keep_cost: "keep per step costs and add to segment" = False,
    ):
        """
        All outputs from public methods are torch arrays on default device
        """
        self._act_fn = act_fn
        if not isinstance(venv, VecMonitor2):
            venv = VecMonitor2(
                venv,
                keep_buf=keep_buf,
                keep_sep_eps=keep_sep_eps,
                keep_non_rolling=keep_non_rolling,
            )
        self._venv = venv
        self._step_count = 0
        self._state = initial_state
        self._infos = None
        self._keep_cost = keep_cost
        self.has_non_rolling_eps = keep_non_rolling

    @property
    def interact_count(self) -> int:
        return self.step_count * self._venv.num

    @property
    def step_count(self) -> int:
        return self._step_count

    @property
    def episode_count(self) -> int:
        return self._venv.epcount

    @property
    def recent_episodes(self) -> list:
        return self._venv.ep_buf.copy()

    @property
    def recent_eplens(self) -> list:
        return [ep.len for ep in self._venv.ep_buf]

    @property
    def recent_eprets(self) -> list:
        return [ep.ret for ep in self._venv.ep_buf]

    @property
    def recent_epinfos(self) -> list:
        return [ep.info for ep in self._venv.ep_buf]

    @property
    def per_env_episodes(self) -> list:
        return self._venv.per_env_buf

    @property
    def non_rolling_eplens(self) -> list:
        if self._venv.non_rolling_buf is None:
            return None
        return [ep.len for ep in self._venv.non_rolling_buf]

    @property
    def non_rolling_eprets(self) -> list:
        if self._venv.non_rolling_buf is None:
            return None
        return [ep.ret for ep in self._venv.non_rolling_buf]

    @property
    def non_rolling_epinfos(self) -> list:
        if self._venv.non_rolling_buf is None:
            return None
        return [ep.info for ep in self._venv.non_rolling_buf]

    def clear_episode_bufs(self):
        self._venv.clear_episode_bufs()

    def clear_per_env_episode_buf(self):
        self._venv.clear_per_env_episode_buf()

    def clear_non_rolling_episode_buf(self):
        self._venv.clear_non_rolling_episode_buf()

    @staticmethod
    def singles_to_multi(single_steps) -> dict:
        """
        Stack single-step dicts into arrays with leading axes (batch, time)
        """
        out = defaultdict(list)
        for d in single_steps:
            for (k, v) in d.items():
                out[k].append(v)

        # TODO stack
        def toarr(xs):
            if isinstance(xs[0], dict):
                return {k: toarr([x[k] for x in xs]) for k in xs[0].keys()}
            if not tu.allsame([x.dtype for x in xs]):
                raise ValueError(
                    f"Timesteps produced data of different dtypes: {set(x.dtype for x in xs)}"
                )
            if isinstance(xs[0], th.Tensor):
                return th.stack(xs, dim=1).to(device=tu.dev())
            elif isinstance(xs[0], np.ndarray):
                arr = np.stack(xs, axis=1)
                return tu.np2th(arr)
            else:
                raise NotImplementedError

        return {k: toarr(v) for (k, v) in out.items()}
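    # Worked example (a sketch, not part of the original source): stacking
    # three per-step dicts whose "ob" entries are (num_envs=4, ob_dim=7)
    # tensors inserts the time axis at dim=1:
    #
    #     steps = [{"ob": th.zeros(4, 7)} for _ in range(3)]
    #     Roller.singles_to_multi(steps)["ob"].shape  # -> (4, 3, 7)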
    def multi_step(self, nstep, **act_kwargs) -> dict:
        """
        step vectorized environment nstep times, return results
        the final observation, first flag, and RNN state are included
        in the segment as "finalob", "finalfirst", and "finalstate"
        """
        if self._venv.num == 0:
            self._step_count += nstep
            return {}
        state_in = self.get_state()
        singles = [self.single_step(**act_kwargs) for _ in range(nstep)]
        out = self.singles_to_multi(singles)
        out["state_in"] = state_in
        finalrew, out["finalob"], out["finalfirst"] = tree_map(
            tu.np2th, self._venv.observe()
        )
        out["finalstate"] = self.get_state()
        # observe() reports the reward for the *previous* action, so shift
        # "lastrew" left by one step and append the final reward: after this,
        # reward[:, t] is the reward received for the action taken at time t.
        out["reward"] = th.cat([out["lastrew"][:, 1:], finalrew[:, None]], dim=1)
        if self._keep_cost:
            out["finalcost"] = tu.np2th(
                np.array([i.get("cost", 0.0) for i in self._venv.get_info()])
            )
            out["cost"] = th.cat(
                [out["lastcost"][:, 1:], out["finalcost"][:, None]], dim=1
            )
        del out["lastrew"]
        return out

    def single_step(self, **act_kwargs) -> dict:
        """
        step vectorized environment once, return results
        """
        out = {}
        lastrew, ob, first = tree_map(tu.np2th, self._venv.observe())
        if self._keep_cost:
            out.update(
                lastcost=tu.np2th(
                    np.array([i.get("cost", 0.0) for i in self._venv.get_info()])
                )
            )
        ac, newstate, other_outs = self._act_fn(
            ob=ob, first=first, state_in=self._state, **act_kwargs
        )
        self._state = newstate
        out.update(lastrew=lastrew, ob=ob, first=first, ac=ac)
        self._venv.act(tree_map(tu.th2np, ac))
        for (k, v) in other_outs.items():
            out[k] = v
        self._step_count += 1
        return out

    def get_state(self):
        return self._state

    def observe(self):
        return self._venv.observe()
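
# ----------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). Assumes a
# gym3-style vectorized env -- `gym3.vectorize_gym` and the discrete-action
# lookup `venv.ac_space.eltype.n` below reflect the gym3 API but are this
# example's assumptions, not requirements of Roller itself.
if __name__ == "__main__":
    import gym3

    venv = gym3.vectorize_gym(num=4, env_kwargs={"id": "CartPole-v0"})

    def act_fn(*, ob, first, state_in):
        # Random discrete policy; a real act_fn would run the policy network
        # and return (action, new_rnn_state, dict_of_extra_outputs).
        ac = th.randint(venv.ac_space.eltype.n, size=(venv.num,))
        return ac, state_in, {}

    roller = Roller(venv=venv, act_fn=act_fn, initial_state=None)
    seg = roller.multi_step(nstep=8)
    # Each tensor in the segment has leading axes (batch=num_envs, time=nstep).
    print({k: tuple(v.shape) for k, v in seg.items() if isinstance(v, th.Tensor)})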