def batch_len(batch)

in ppo_ewma/torch_util.py


import torch as th

from . import tree_util  # the repo's local JAX-style tree utilities (tree_flatten etc.)


def batch_len(batch):
    """
    Given a nested dict of arrays with the same batch size, return that batch size
    """
    # Flatten the nested structure into a flat list of leaf arrays
    flatlist, _ = tree_util.tree_flatten(batch)
    if len(flatlist) < 1:
        return 0
    # Take the leading dimension of the first leaf as the batch size
    b = flatlist[0].shape[0]
    # Verify every tensor leaf shares that leading (batch) dimension
    assert all(
        arr.shape[0] == b for arr in flatlist if th.is_tensor(arr)
    ), "Not all arrays have same batchsize!"
    return b
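
For illustration, a minimal usage sketch. The dict keys and tensor shapes below are hypothetical, and it assumes tree_flatten traverses nested dicts as in JAX:

import torch as th

batch = {
    "obs": {"rgb": th.zeros(8, 64, 64, 3)},  # leading dim 8 is the batch size
    "action": th.zeros(8, 4),
    "reward": th.zeros(8),
}
assert batch_len(batch) == 8  # all leaves share the leading dimension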