in petastorm/pytorch.py
def __iter__(self):
"""
The Data Loader iterator stops the for-loop when num_epochs is reached.
"""
if self._in_iter:
raise RuntimeError("InMemBatchedDataLoader couldn't be used multiple times, please\
specify total number of epochs using num_epochs in constructor.")
self._in_iter = True
for epoch in range(self._num_epochs):
size = len(self._buffer[0])
if self._shuffle:
            # Deterministically shuffle based on the seed and the current epoch id,
            # so each epoch sees a different but reproducible ordering.
g = torch.Generator()
g.manual_seed(self._seed + epoch)
indices = torch.randperm(size, generator=g).tolist()
else:
indices = list(range(size))
        # Slice the index list into consecutive batches; the final batch may be
        # smaller when the buffer size is not a multiple of the batch size.
        for i in range(0, len(indices), self._batch_size):
            idx = indices[i:i + self._batch_size]
            # Gather the same rows from every buffered column.
            batch = [v[idx] for v in self._buffer]
            size -= len(batch[0])
            batch = dict(zip(self._keys, batch))
            yield batch
        # Sanity check: every row was yielded exactly once during the epoch.
        assert size == 0
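
For context, a minimal usage sketch follows. The dataset URL and the constructor arguments (batch_size, num_epochs, shuffle, seed) are assumptions that mirror the attributes this method reads; the exact InMemBatchedDataLoader signature should be checked against the petastorm source.

from petastorm import make_reader
from petastorm.pytorch import InMemBatchedDataLoader

# URL and constructor arguments below are illustrative assumptions.
with make_reader('file:///tmp/my_dataset') as reader:
    loader = InMemBatchedDataLoader(reader, batch_size=128, num_epochs=3,
                                    shuffle=True, seed=42)
    # A single for-loop covers all three epochs; iterating `loader` a
    # second time would raise the RuntimeError above.
    for batch in loader:
        pass  # batch is a dict mapping column names to tensors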
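The shuffle step is worth isolating: seeding a fresh torch.Generator with seed + epoch makes every epoch's permutation reproducible without disturbing PyTorch's global RNG state. A standalone sketch, where epoch_permutation is a hypothetical helper rather than part of petastorm:

import torch

def epoch_permutation(size, seed, epoch):
    # Hypothetical helper mirroring the shuffle above: a local generator
    # seeded with (seed + epoch) yields the same order on every call.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(size, generator=g).tolist()

first = epoch_permutation(100, seed=42, epoch=0)
again = epoch_permutation(100, seed=42, epoch=0)
assert first == again  # identical within an epoch, run after run
other = epoch_permutation(100, seed=42, epoch=1)
# `other` is a different (but equally reproducible) permutation for epoch 1.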
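Batch assembly relies on PyTorch's advanced indexing: indexing a tensor with a Python list of row ids gathers exactly those rows in one shot, which is what `v[idx]` does for each buffered column. A small illustration with made-up column names:

import torch

buffer = [torch.arange(5), torch.arange(5) * 10]  # two buffered "columns"
keys = ['id', 'value']                            # made-up column names
idx = [3, 0, 4]                                   # one batch of row indices
batch = dict(zip(keys, (col[idx] for col in buffer)))
print(batch)  # {'id': tensor([3, 0, 4]), 'value': tensor([30, 0, 40])}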