in ppo_ewma/roller.py [0:0]
def singles_to_multi(single_steps) -> dict:
    """
    Stack single-step dicts into arrays with leading axes (batch, time)

    Each element of ``single_steps`` is a dict holding the data of one
    timestep.  Values found under the same key are collected in iteration
    order and stacked along a new axis 1, so every leaf gains a time axis
    directly after its leading axis.  Values may themselves be dicts, in
    which case they are stacked recursively, key by key (keys are taken from
    the first timestep).

    Args:
        single_steps: iterable of dicts whose leaves are ``th.Tensor`` or
            ``np.ndarray`` objects (possibly nested inside further dicts).

    Returns:
        A plain dict with the same (nested) keys.  Tensor leaves are stacked
        with ``th.stack`` and moved to ``tu.dev()``; ndarray leaves are
        stacked with ``np.stack`` and passed through ``tu.np2th``
        (presumably a numpy -> torch conversion; confirm in ``tu``).
        An empty ``single_steps`` yields an empty dict.

    Raises:
        ValueError: if the leaves collected for one key have differing dtypes.
        NotImplementedError: if a leaf is neither a dict, a ``th.Tensor`` nor
            an ``np.ndarray``.
    """
    # Group the per-timestep values by key, preserving timestep order.
    out = defaultdict(list)
    for d in single_steps:
        for k, v in d.items():
            out[k].append(v)

    def toarr(xs):
        # The first timestep decides how this key is handled.
        first = xs[0]
        if isinstance(first, dict):
            return {k: toarr([x[k] for x in xs]) for k in first}
        # Reject unsupported leaves *before* touching ``.dtype``: plain Python
        # values have no such attribute and would otherwise fail with an
        # unrelated AttributeError instead of the intended error below.
        if not isinstance(first, (th.Tensor, np.ndarray)):
            raise NotImplementedError(
                f"Cannot stack values of type {type(first).__name__}; "
                "expected th.Tensor, np.ndarray or dict"
            )
        dtypes = [x.dtype for x in xs]
        if not tu.allsame(dtypes):
            raise ValueError(
                f"Timesteps produced data of different dtypes: {set(dtypes)}"
            )
        if isinstance(first, th.Tensor):
            return th.stack(xs, dim=1).to(device=tu.dev())
        return tu.np2th(np.stack(xs, axis=1))

    return {k: toarr(v) for k, v in out.items()}