in ppo_ewma/minibatch_optimize.py [0:0]
def minibatch_gen(data, *, batch_size=None, nminibatch=None, forever=False):
    """
    Generator yielding randomly ordered minibatches of ``data``.

    Exactly one of ``batch_size`` / ``nminibatch`` has to be supplied. With
    ``batch_size``, the number of minibatches per pass is
    ``max(ntrain // batch_size, 1)``, where ``ntrain = tu.batch_len(data)``.

    Each pass draws a fresh random permutation of ``range(ntrain)``, splits
    it into ``nminibatch`` index groups, slices ``data`` with each group via
    ``tu.tree_slice`` and maps ``to_th_device`` over the result.

    If ``nminibatch`` is larger than ``ntrain`` it must be a multiple of it.
    In that case ``ntrain`` index groups are used and every sliced batch is
    additionally cut into ``nminibatch // ntrain`` pieces along dimension 1,
    each piece being yielded separately.

    NOTE(review): ``data`` is presumably a tree of tensors whose leaves have
    at least two dimensions (``th.chunk(..., dim=1)`` is applied to every
    leaf, even when only one piece is requested) -- confirm with callers.

    :param data: tree-structured batch accepted by ``tu.batch_len`` and
        ``tu.tree_slice``.
    :param batch_size: approximate number of entries per minibatch.
    :param nminibatch: number of minibatches per pass over ``data``.
    :param forever: when truthy, start a new shuffled pass after each one
        instead of stopping after the first.
    """
    exactly_one_given = (batch_size is None) != (nminibatch is None)
    assert (
        exactly_one_given
    ), "only one of batch_size or nminibatch should be specified"

    ntrain = tu.batch_len(data)
    if nminibatch is None:
        nminibatch = max(ntrain // batch_size, 1)

    # Default: every sliced batch is yielded as a single piece.
    pieces_per_batch = 1
    if nminibatch > ntrain:
        assert (
            nminibatch % ntrain == 0
        ), "nminibatch must be a multiple of ntrain if it is larger"
        # More minibatches requested than entries: use one index group per
        # entry and make up the difference by splitting along dimension 1.
        pieces_per_batch = nminibatch // ntrain
        nminibatch = ntrain

    run_another_pass = True
    while run_another_pass:
        shuffled = th.randperm(ntrain)
        for index_group in th.chunk(shuffled, nminibatch):
            sliced = tu.tree_slice(data, index_group)
            on_device = tree_map(to_th_device, sliced)
            for piece in range(pieces_per_batch):

                def pick_piece(leaf):
                    # raises an IndexError if nminibatch is incompatible with num_envs and nstep
                    return th.chunk(leaf, pieces_per_batch, dim=1)[piece]

                yield tree_map(pick_piece, on_device)
        # A single pass unless the caller asked for an endless stream.
        run_another_pass = bool(forever)