in petastorm/pytorch.py [0:0]
def _iter_impl(self):
    """
    Yield batches until the reader runs out of samples.

    Every row read from ``self.reader`` is converted to a dict, sanitized with ``_sanitize_pytorch_types`` and
    pushed into a freshly created shuffling buffer. The buffer is drained through ``self._yield_batches`` after
    each read. Once the reader is exhausted, the buffer is told to finish, drained one more time, and whatever
    is still left in ``self._batch_acc`` is collated and emitted as a final, possibly smaller, batch.

    The accumulator is emptied before that final batch is handed out, so rows that were already emitted are not
    carried over into a later pass that reuses ``self._batch_acc``.

    :return: A generator of batches. The final partial batch is the result of ``self.collate_fn``; all other
      batches are whatever ``self._yield_batches`` produces.
    """
    # As we iterate over incoming samples, we are going to store them in `self._batch_acc`, until we have a batch of
    # the requested batch_size ready.

    # Stays None if the reader yields no rows at all.
    keys = None
    if self.shuffling_queue_capacity > 0:
        # We can not know what is the reasonable number to use for the extra capacity, so we set a huge number
        # and give up on the unbound growth protection mechanism.
        min_after_dequeue = self.shuffling_queue_capacity - 1
        self._shuffling_buffer = RandomShufflingBuffer(self.shuffling_queue_capacity,
                                                       min_after_retrieve=min_after_dequeue,
                                                       extra_capacity=100000000)
    else:
        self._shuffling_buffer = NoopShufflingBuffer()

    for row in self.reader:
        # Default collate does not work nicely on namedtuples and treat them as lists
        # Using dict will result in the yielded structures being dicts as well
        row_as_dict = row._asdict()

        keys = row_as_dict.keys()

        # Promote some types that are incompatible with pytorch to be pytorch friendly.
        _sanitize_pytorch_types(row_as_dict)

        # Add rows to shuffling buffer
        if not self.reader.is_batched_reader:
            self._shuffling_buffer.add_many([row_as_dict])
        else:
            # Transposition:
            #   row_as_dict:        {'a': [1,2,3], 'b':[4,5,6]}
            #   row_group_as_tuple: [(1, 4), (2, 5), (3, 6)]
            # The order within a tuple is defined by key order in 'keys'
            row_group_as_tuple = list(zip(*(row_as_dict[k] for k in keys)))

            # Adding data as 'row-by-row' into a shuffling buffer. This is a pretty
            # slow implementation though. Probably can come up with a faster way to shuffle,
            # perhaps at the expense of a larger memory consumption...
            self._shuffling_buffer.add_many(row_group_as_tuple)

        # _yield_batches will emit as many batches as are allowed by the shuffling_buffer (RandomShufflingBuffer
        # will avoid underflowing below a certain number of samples to guarantee some samples decorrelation)
        for batch in self._yield_batches(keys):
            yield batch

    # Once reader can not read new rows, we might still have a bunch of rows waiting in the shuffling buffer.
    # Telling shuffling buffer that we are finished allows to deplete the buffer completely, regardless its
    # min_after_dequeue setting.
    self._shuffling_buffer.finish()

    for batch in self._yield_batches(keys):
        yield batch

    # Yield the last and partial batch
    if self._batch_acc:
        last_batch = self.collate_fn(self._batch_acc)
        # Rebind the accumulator *before* yielding: a consumer may never resume this generator after the last
        # batch, and leftover rows would otherwise be emitted again by the next pass that appends to it.
        self._batch_acc = []
        yield last_batch