def read_batch()

in point_e/evals/npz_stream.py [0:0]


    def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
        """
        Read one batch from every reader and bundle the results by key.

        Each reader in self.readers is called with batch_size before any
        validation takes place; the results are then checked for agreement.

        :param batch_size: forwarded unchanged to each reader's read_batch().
        :return: a dict pairing self.keys with the readers' batches by
                 position, or None when every reader returned None.
        :raises RuntimeError: if only some of the readers returned None, if
                              the returned batches differ in length, or if
                              there are no readers at all.
        """
        chunks = [reader.read_batch(batch_size) for reader in self.readers]

        # Compare "at least one None" against "nothing but None". The two
        # disagree when the readers are out of step, and also when the list
        # is empty (zero Nones, yet vacuously "all" None).
        num_missing = sum(1 for chunk in chunks if chunk is None)
        some_missing = num_missing > 0
        every_missing = num_missing == len(chunks)
        if some_missing != every_missing:
            raise RuntimeError("different keys had different numbers of elements")

        if some_missing:
            # Every reader returned None.
            return None

        # All readers produced a batch; each must match the first in length.
        for chunk in chunks:
            if len(chunk) != len(chunks[0]):
                raise RuntimeError("different keys had different numbers of elements")

        return {key: chunk for key, chunk in zip(self.keys, chunks)}