in src/chug/wds/shardlists.py [0:0]
def __iter__(self):
    """Yield one ``{"url": ...}`` sample per shard selected for the current interval.

    The interval index is taken from ``self.interval.get_value()`` when
    ``self.interval`` is a ``SharedCount``; otherwise ``self.interval`` is
    incremented in place and the new value is used.

    When ``self.seed`` is set, a copy of ``self.urls`` is shuffled with a
    ``random.Random`` seeded from the seed and the interval, so the order is
    reproducible for a given (seed, interval) pair. When
    ``self.num_sub_intervals`` is set, only every ``num_sub_intervals``-th
    shard is yielded, starting at ``interval % num_sub_intervals``.
    """
    # Work on a copy so shuffling/slicing never touches self.urls.
    shard_urls = self.urls.copy()

    # Resolve the current interval (epoch) index.
    if not isinstance(self.interval, SharedCount):
        # NOTE: local interval tracking is problematic in a multiprocess setting
        # (dataloader workers or train processes), as different workers may wrap
        # at different times (or not at all).
        self.interval += 1
        current = self.interval
    else:
        current = self.interval.get_value()

    num_sub = self.num_sub_intervals

    if self.seed is not None:
        # Use the same seed on every node/worker. With sub intervals the seed
        # only changes once per super interval, so the shuffle stays consistent
        # across all sub intervals belonging to it.
        seed_offset = current if num_sub is None else current // num_sub
        random.Random(self.seed + seed_offset).shuffle(shard_urls)

    if num_sub is not None:
        # Keep only the shards that belong to this sub interval.
        shard_urls = shard_urls[current % num_sub::num_sub]

    for shard_url in shard_urls:
        yield {"url": shard_url}