in point_e/evals/npz_stream.py [0:0]
def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
    """
    Yield the arrays named by `keys` from `self.paths` in batches.

    The paths are opened one after another and read through a
    CombinedReader. A partially filled batch left over when one file runs
    out is carried into the next file, so only the last yielded batch may
    be smaller than `batch_size`. If `self.trunc_length` is not None, the
    total number of items read across all files is capped at that value.

    :param batch_size: target size of every yielded batch; must be positive.
    :param keys: names passed to open_npz_arrays / CombinedReader to select
                 which arrays are read.
    :return: an iterator over batch dicts as produced by
             CombinedReader.read_batch, with carried-over pieces joined
             along axis 0.
    :raises ValueError: if `batch_size` is not positive. Because this is a
                        generator, the error is raised on first iteration
                        rather than when the method is called.
    """
    # A zero batch size would make the "batch is full" check below succeed
    # for empty batches, and a negative one would be forwarded to the reader
    # as a negative count, so reject both up front.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Partial batch carried across loop iterations and file boundaries.
    cur_batch = None
    # Items still allowed by the truncation limit; None means unlimited.
    num_remaining = self.trunc_length

    for path in self.paths:
        # Stop before opening another file once the limit is used up.
        if num_remaining is not None and num_remaining <= 0:
            break
        with open_npz_arrays(path, keys) as readers:
            combined_reader = CombinedReader(keys, readers)
            while num_remaining is None or num_remaining > 0:
                # Only request what is needed to top up the pending batch,
                # and never more than the truncation limit still allows.
                read_bs = batch_size
                if cur_batch is not None:
                    read_bs -= _dict_batch_size(cur_batch)
                if num_remaining is not None:
                    read_bs = min(read_bs, num_remaining)

                batch = combined_reader.read_batch(read_bs)
                if batch is None:
                    # Nothing more to read here; move on to the next path
                    # while keeping any partial batch.
                    break
                if num_remaining is not None:
                    num_remaining -= _dict_batch_size(batch)

                if cur_batch is None:
                    cur_batch = batch
                else:
                    # Append the new piece to the carried-over partial batch.
                    cur_batch = {
                        # pylint: disable=unsubscriptable-object
                        k: np.concatenate([cur_batch[k], v], axis=0)
                        for k, v in batch.items()
                    }

                if _dict_batch_size(cur_batch) == batch_size:
                    yield cur_batch
                    cur_batch = None

    # Flush whatever is left as a final, possibly short, batch.
    if cur_batch is not None:
        yield cur_batch