def singles_to_multi()

in ppo_ewma/roller.py [0:0]


    def singles_to_multi(single_steps) -> dict:
        """
        Stack single-step dicts into arrays with leading axes (batch, time).

        Values are collected per key across ``single_steps`` (in order) and
        stacked along a new axis 1. A per-step value whose leading axis is
        batch therefore becomes ``(batch, time, ...)``. Nested dicts are
        stacked recursively, leaf by leaf.

        Parameters:
            single_steps: iterable of per-step dicts whose leaves are
                ``th.Tensor`` or ``np.ndarray`` (or nested dicts of those).

        Returns:
            dict with one entry per key seen in any step. Tensor leaves are
            stacked with ``th.stack`` and moved to ``tu.dev()``. ndarray
            leaves are stacked with ``np.stack`` and passed through
            ``tu.np2th`` (presumably yielding a torch tensor; confirm). An
            empty input yields an empty dict.

        Raises:
            ValueError: if a key's values mix tensors and ndarrays across
                timesteps, or have differing dtypes.
            NotImplementedError: if a leaf is neither a dict, a tensor, nor
                an ndarray.
        """
        out = defaultdict(list)
        for d in single_steps:
            for k, v in d.items():
                out[k].append(v)
        # NOTE(review): a key missing from some steps gets a shorter time axis
        # than the others; callers appear to rely on every step sharing keys.

        def toarr(xs):
            first = xs[0]
            if isinstance(first, dict):
                # Nested dict (e.g. per-module state): stack each leaf separately.
                return {k: toarr([x[k] for x in xs]) for k in first}

            # Dispatch on the container type *before* reading ``.dtype``, so an
            # unsupported leaf raises NotImplementedError rather than an
            # AttributeError from the dtype check.
            if isinstance(first, th.Tensor):
                kind = th.Tensor
            elif isinstance(first, np.ndarray):
                kind = np.ndarray
            else:
                raise NotImplementedError(
                    f"Cannot stack timestep data of type {type(first).__name__}"
                )

            if not all(isinstance(x, kind) for x in xs):
                type_names = sorted({type(x).__name__ for x in xs})
                raise ValueError(
                    f"Timesteps produced data of different types: {type_names}"
                )

            dtypes = [x.dtype for x in xs]
            if not tu.allsame(dtypes):
                raise ValueError(
                    f"Timesteps produced data of different dtypes: {set(dtypes)}"
                )

            if kind is th.Tensor:
                return th.stack(xs, dim=1).to(device=tu.dev())
            return tu.np2th(np.stack(xs, axis=1))

        return {k: toarr(v) for k, v in out.items()}