in ppo_ewma/ppg.py [0:0]
def make_minibatches(segs, mbsize):
    """
    Yield one epoch of minibatches over the dataset described by segs.
    Each minibatch mixes data between different segs.

    Every (env index, seg index) pair is visited exactly once per epoch, in
    a random order drawn from th.randperm.

    Args:
        segs: indexable sequence of trees; tu.batch_len(segs[0]) is used as
            the number of env slices available in every seg
            (NOTE(review): assumes all segs share that batch length).
        mbsize: either a positive integer, giving the number of (env, seg)
            pairs stacked into each minibatch, or 1 divided by a positive
            integer k, in which case every minibatch holds a single
            (env, seg) pair that is split into k chunks along dim 1 and the
            chunks are yielded one at a time.

    Yields:
        The stacked minibatch tree (or one dim-1 chunk of it). The last
        minibatch of the epoch may hold fewer than mbsize pairs.

    Raises:
        ValueError: if mbsize is not positive, or is neither an integer nor
            1 divided by an integer.
        IndexError: if mbsize < 1 and dim 1 of a leaf cannot be split into
            the requested number of chunks.
    """
    # Validate explicitly rather than with `assert`, which is removed under
    # `python -O` and would let a bad mbsize silently round to some chunking.
    if mbsize <= 0:
        # Previously 0 raised ZeroDivisionError and e.g. -0.5 produced an
        # empty epoch without any error.
        raise ValueError(f"mbsize must be positive, got {mbsize!r}")
    if mbsize < 1:
        nchunks = int(round(1 / mbsize))
        if not th.isclose(th.tensor(1 / mbsize), th.tensor(float(nchunks))):
            raise ValueError("mbsize must be an integer or 1 divided by an integer")
        mbsize = 1
    else:
        if int(mbsize) != mbsize:
            raise ValueError("mbsize must be an integer or 1 divided by an integer")
        # Tensor.split needs a real int; this also admits integral floats.
        mbsize = int(mbsize)
        nchunks = 1
    nenv = tu.batch_len(segs[0])
    nseg = len(segs)
    # Row i is one (env index, seg index) pair; all pairs are enumerated.
    envs_segs = th.tensor(list(itertools.product(range(nenv), range(nseg))))
    for perminds in th.randperm(len(envs_segs)).split(mbsize):
        esinds = envs_segs[perminds]
        mb = tu.tree_stack(
            [tu.tree_slice(segs[segind], envind) for (envind, segind) in esinds]
        )
        for chunknum in range(nchunks):
            # th.chunk may return fewer than nchunks pieces when dim 1 is too
            # short or not evenly divisible; indexing then raises IndexError
            # (i.e. a fractional mbsize incompatible with nstep).
            yield tree_map(lambda x: th.chunk(x, nchunks, dim=1)[chunknum], mb)