in point_e/evals/npz_stream.py [0:0]
def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
    """
    Pull one batch from each reader and return them keyed by ``self.keys``.

    Every reader in ``self.readers`` is asked for a batch (``batch_size`` is
    forwarded unchanged) before any consistency check runs. Results are paired
    with ``self.keys`` positionally via ``zip``, so the shorter of the two
    sequences determines how many entries the returned dict has.

    :param batch_size: value passed to each reader's ``read_batch``.
    :return: ``None`` when every reader returned ``None`` (presumably meaning
             all streams are exhausted); otherwise a dict mapping each key to
             the batch returned by the reader at the same position.
    :raises RuntimeError: if some but not all readers returned ``None``, if
                          ``self.readers`` is empty, or if the returned
                          batches do not all have the same ``len()``.
    """
    mismatch_msg = "different keys had different numbers of elements"
    results = [reader.read_batch(batch_size) for reader in self.readers]

    # Exactly one "is None" state must be present: {True} means every reader
    # is done, {False} means every reader produced data. Mixed states, or no
    # readers at all (empty set), are reported as a mismatch.
    none_states = {res is None for res in results}
    if len(none_states) != 1:
        raise RuntimeError(mismatch_msg)
    if True in none_states:
        return None

    # All readers produced data; their leading lengths must agree.
    head_len = len(results[0])
    for res in results:
        if len(res) != head_len:
            raise RuntimeError(mismatch_msg)

    return {key: res for key, res in zip(self.keys, results)}