in iterators.py
def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):

    def shuffle_batches(batches, seed):
        # Shuffle deterministically from the epoch seed so every shard agrees on the order
        with numpy_seed(seed):
            np.random.shuffle(batches)
        return batches

    if self._supports_prefetch:
        batches = self.frozen_batches

        # Shuffle before sharding so each GPU sees a different subset every epoch
        if shuffle and not fix_batches_to_gpus:
            batches = shuffle_batches(list(batches), self.seed + epoch)

        batches = list(ShardedIterator(
            batches, self.num_shards, self.shard_id, fill_value=[]
        ))
        # Prefetch only the indices this shard will actually touch
        self.dataset.prefetch([i for s in batches for i in s])

        # When batches are pinned to GPUs, shuffle within this shard only
        if shuffle and fix_batches_to_gpus:
            batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
    else:
        if shuffle:
            batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
        else:
            batches = self.frozen_batches
        batches = list(ShardedIterator(
            batches, self.num_shards, self.shard_id, fill_value=[]
        ))

    # Nothing left to iterate when resuming past the end of this epoch
    if offset > 0 and offset >= len(batches):
        return None

    if self.num_workers > 0:
        # Silence the semaphore_tracker warnings emitted by worker processes
        os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'

    # Create data loader
    itr = torch.utils.data.DataLoader(
        self.dataset,
        collate_fn=self.collate_fn,
        batch_sampler=batches[offset:],
        num_workers=self.num_workers,
        timeout=self.timeout,
    )

    # Wrap with a BufferedIterator if needed
    if self.buffer_size > 0:
        itr = BufferedIterator(self.buffer_size, itr)

    # Wrap with CountingIterator so consumed batches are counted from `offset`
    itr = CountingIterator(itr, start=offset)
    return itr
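
For context, a minimal sketch of how the returned iterator might be consumed. The `epoch_itr` name, the loop body, and the `itr.n` count attribute are illustrative assumptions, not part of this file; the method itself relies on module-level imports of os, numpy (as np), and torch, plus the numpy_seed, ShardedIterator, BufferedIterator, and CountingIterator helpers defined or imported elsewhere in iterators.py.

# Hypothetical caller: `epoch_itr` is an object exposing the method above.
itr = epoch_itr._get_iterator_for_epoch(epoch=1, shuffle=True, offset=0)
if itr is not None:
    for batch in itr:
        ...  # run a training step on `batch`
    # Assuming CountingIterator exposes the number of consumed batches (e.g. itr.n),
    # a later call with offset=itr.n could resume this epoch after a restart.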