in petastorm/pytorch.py [0:0]
def _iter_impl(self):
    """
    The Data Loader iterator stops the for-loop when reader runs out of samples.
    """
    # Incoming rows are accumulated in a shuffling buffer until enough samples
    # are available to release batches of the requested batch_size.
    keys = None

    if self.shuffling_queue_capacity > 0:
        # There is no principled way to pick a reasonable extra capacity, so a
        # huge value is used, effectively giving up the unbounded-growth
        # protection mechanism. To keep the same behavior as DataLoader, the
        # shuffling queue capacity is enlarged by one batch.
        min_after_dequeue = self.shuffling_queue_capacity - 1
        effective_capacity = min_after_dequeue + self.batch_size
        self._shuffling_buffer = BatchedRandomShufflingBuffer(
            effective_capacity,
            min_after_retrieve=min_after_dequeue,
            extra_capacity=100000000,
            batch_size=self.batch_size)
    else:
        self._shuffling_buffer = BatchedNoopShufflingBuffer(batch_size=self.batch_size)

    for row in self.reader:
        # The default collate mishandles namedtuples (treats them as plain
        # lists), so each row is converted to a dict; the structures yielded
        # downstream will therefore be dicts as well.
        sample = row._asdict()
        keys = sample.keys()

        # Coerce types pytorch cannot consume into pytorch-friendly ones.
        _sanitize_pytorch_types(sample)

        # Apply the user transform column-by-column before buffering. A
        # non-batched reader produces scalar column values, which are wrapped
        # into a one-element list first.
        is_batched = self.reader.is_batched_reader
        for column, value in sample.items():
            sample[column] = self.transform_fn(value if is_batched else [value])

        self._shuffling_buffer.add_many(sample.values())

        # Emit every batch the buffer is currently willing to release. A
        # random shuffling buffer holds back samples (its min_after_dequeue
        # setting) to guarantee some decorrelation between emitted samples.
        for ready_batch in self._yield_batches(keys):
            yield ready_batch

    # The reader is exhausted, but rows may still be waiting in the shuffling
    # buffer. Marking the buffer finished lets it drain completely, regardless
    # of its min_after_dequeue setting.
    self._shuffling_buffer.finish()
    for ready_batch in self._yield_batches(keys):
        yield ready_batch