def make_minibatches()

in ppo_ewma/ppg.py [0:0]


def make_minibatches(segs, mbsize):
    """
    Yield one epoch of minibatches over the dataset described by ``segs``.

    The dataset is indexed by every (env, seg) pair, with ``nenv`` taken from
    ``tu.batch_len(segs[0])`` and ``nseg = len(segs)``. Pairs are visited in a
    random order (``th.randperm``), so each minibatch mixes data between
    different segs and envs.

    Args:
        segs: sequence of segment trees; ``tu.tree_slice(seg, envind)`` is
            used to select a single env from each. (Presumably every seg has
            the same number of envs as ``segs[0]`` -- TODO confirm with
            callers.)
        mbsize: number of (env, seg) pairs per minibatch when >= 1 (must be
            integral). When < 1 it must be ``1 / k`` for a positive integer
            ``k``: each minibatch then holds a single (env, seg) pair, which
            is split into ``k`` pieces along dim 1 of every leaf and yielded
            piece by piece.

    Yields:
        Trees built by ``tu.tree_stack`` over the selected slices (or the
        dim-1 chunks thereof). The last minibatch holds fewer pairs when the
        number of pairs is not a multiple of ``mbsize``.

    Raises:
        ValueError: if ``mbsize`` is not positive, or is neither an integer
            nor 1 divided by an integer. Because this is a generator, the
            check runs on the first ``next()``, not when it is called.
        IndexError: if ``1 / mbsize`` pieces cannot be formed along dim 1
            (``th.chunk`` may return fewer chunks than requested).
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if mbsize <= 0:
        raise ValueError(f"mbsize must be positive, got {mbsize!r}")
    if mbsize < 1:
        nchunks = int(round(1 / mbsize))
        if not th.isclose(th.tensor(1 / mbsize), th.tensor(float(nchunks))):
            raise ValueError("mbsize must be an integer or 1 divided by an integer")
        # Each minibatch is one (env, seg) pair, later split into nchunks pieces.
        mbsize = 1
    else:
        if not float(mbsize).is_integer():
            raise ValueError("mbsize must be an integer or 1 divided by an integer")
        # Tensor.split expects an int split size; accept integral floats like 4.0.
        mbsize = int(mbsize)
        nchunks = 1
    nenv = tu.batch_len(segs[0])
    nseg = len(segs)
    # One row per (env index, seg index) pair.
    envs_segs = th.tensor(list(itertools.product(range(nenv), range(nseg))))
    for perminds in th.randperm(len(envs_segs)).split(mbsize):
        esinds = envs_segs[perminds]
        mb = tu.tree_stack(
            [tu.tree_slice(segs[segind], envind) for (envind, segind) in esinds]
        )
        for chunknum in range(nchunks):
            # Raises an IndexError if th.chunk yields fewer than nchunks pieces,
            # i.e. when a fractional mbsize is incompatible with the size of
            # dim 1 (presumably nstep). tree_map runs eagerly, so the lambda's
            # late binding of chunknum is safe.
            yield tree_map(lambda x: th.chunk(x, nchunks, dim=1)[chunknum], mb)