in ppo_ewma/torch_util.py [0:0]
def batch_len(batch):
    """
    Given nested dict of arrays with same batchsize, return this batchsize

    The batch size is read from the leading dimension of the first leaf
    produced by ``tree_util.tree_flatten(batch)``. Every leaf that is a torch
    tensor is then checked against that size; non-tensor leaves are not
    checked.

    :param batch: nested structure (e.g. dict of dicts) whose leaves are arrays
    :returns: the shared leading-dimension size, or 0 if ``batch`` has no leaves
    :raises ValueError: if some tensor leaf has a different leading dimension
    """
    flatlist, _ = tree_util.tree_flatten(batch)
    if not flatlist:
        return 0
    b = flatlist[0].shape[0]
    # Explicit check rather than ``assert`` so the validation still runs under
    # ``python -O``; collect the offending sizes to make the error actionable.
    bad_sizes = sorted(
        {arr.shape[0] for arr in flatlist if th.is_tensor(arr) and arr.shape[0] != b}
    )
    if bad_sizes:
        raise ValueError(
            f"Not all arrays have same batchsize! expected {b}, also found {bad_sizes}"
        )
    return b