def stream()

in point_e/evals/npz_stream.py [0:0]


    def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
        """
        Iterate over the files in ``self.paths`` in order, re-chunking their rows
        into dicts of arrays (produced by ``CombinedReader(keys, ...)``).

        Rows are accumulated across file boundaries: a dict is yielded as soon as
        it holds exactly ``batch_size`` rows, and whatever is still pending after
        the last file is yielded once at the end (it may be smaller).

        If ``self.trunc_length`` is not None, reads are capped so that no more
        than that many rows are requested in total, and no further files are
        opened once the budget is used up.

        NOTE(review): the "exactly batch_size" / "at most trunc_length" guarantees
        assume ``read_batch(n)`` never returns more than ``n`` rows — confirm
        against CombinedReader.

        :param batch_size: target number of rows per yielded dict.
        :param keys: array names to read from each npz file.
        :return: a generator of dicts of concatenated arrays.
        """
        pending = None  # partially filled batch, carried across reads and files
        budget = self.trunc_length  # rows still allowed; None means unlimited

        for npz_path in self.paths:
            # Stop before opening another file once the truncation budget is spent.
            if budget is not None and budget <= 0:
                break
            with open_npz_arrays(npz_path, keys) as array_readers:
                reader = CombinedReader(keys, array_readers)
                while budget is None or budget > 0:
                    # Only request enough rows to top up the pending batch.
                    want = (
                        batch_size
                        if pending is None
                        else batch_size - _dict_batch_size(pending)
                    )
                    if budget is not None:
                        want = min(want, budget)

                    chunk = reader.read_batch(want)
                    if chunk is None:
                        # This file is exhausted; move on to the next path.
                        break

                    if budget is not None:
                        budget -= _dict_batch_size(chunk)

                    if pending is not None:
                        pending = {
                            # pylint: disable=unsubscriptable-object
                            name: np.concatenate([pending[name], arr], axis=0)
                            for name, arr in chunk.items()
                        }
                    else:
                        pending = chunk

                    # Yield inside the `with` so the current file stays open while
                    # the consumer handles the batch.
                    if _dict_batch_size(pending) == batch_size:
                        yield pending
                        pending = None

        # Flush the (possibly short) remainder.
        if pending is not None:
            yield pending