def minibatch_gen()

in ppo_ewma/minibatch_optimize.py [0:0]


def minibatch_gen(data, *, batch_size=None, nminibatch=None, forever=False):
    """Generate randomly ordered minibatches drawn from ``data``.

    Exactly one of ``batch_size`` and ``nminibatch`` has to be supplied.
    When ``batch_size`` is supplied, the number of minibatches is derived as
    ``max(tu.batch_len(data) // batch_size, 1)``.

    If more minibatches are requested than ``tu.batch_len(data)`` reports,
    the request must be an exact multiple of that length.  In that case one
    index-level minibatch is produced per entry and every such minibatch is
    further divided along ``dim=1`` into ``nminibatch // length`` pieces,
    which are yielded one after another.

    Args:
        data: tree of values accepted by ``tu.batch_len`` / ``tu.tree_slice``
            (presumably tensors sharing a leading batch dimension -- confirm
            against those helpers).
        batch_size: desired minibatch size; mutually exclusive with
            ``nminibatch``.
        nminibatch: desired number of minibatches per pass; mutually
            exclusive with ``batch_size``.
        forever: when true, start a freshly shuffled pass after each pass
            instead of stopping after the first one.

    Yields:
        A tree with the same structure as ``data`` whose leaves were sliced
        by a random index group, passed through ``to_th_device`` and reduced
        to one ``dim=1`` chunk.

    Raises:
        AssertionError: if both or neither of ``batch_size`` / ``nminibatch``
            are given, or if ``nminibatch`` exceeds the data length without
            being a multiple of it.  Because this is a generator, the checks
            only run once iteration starts.
    """
    size_given = batch_size is not None
    count_given = nminibatch is not None
    assert (
        size_given != count_given
    ), "only one of batch_size or nminibatch should be specified"

    total = tu.batch_len(data)
    if not count_given:
        nminibatch = max(total // batch_size, 1)

    # Number of dim=1 pieces each index-level minibatch is cut into.
    pieces = 1
    if nminibatch > total:
        assert (
            nminibatch % total == 0
        ), "nminibatch must be a multiple of ntrain if it is larger"
        pieces = nminibatch // total
        nminibatch = total

    keep_going = True
    while keep_going:
        shuffled = th.randperm(total)
        for index_group in th.chunk(shuffled, nminibatch):
            selected = tu.tree_slice(data, index_group)
            on_device = tree_map(to_th_device, selected)
            for piece in range(pieces):
                # raises an IndexError if nminibatch is incompatible with
                # num_envs and nstep (th.chunk yields no piece at this index)
                yield tree_map(
                    lambda leaf: th.chunk(leaf, pieces, dim=1)[piece], on_device
                )
        # A single pass unless the caller asked for an endless stream.
        keep_going = forever