in src/chug/wds/shardlists.py [0:0]
def __iter__(self):
    """Yield ``self.nshards`` shard records sampled from ``self.urls``.

    Each yielded item is a dict of the form ``{"url": <shard url>}``. Shards
    are drawn with replacement, optionally weighted by ``self.weights``. When
    ``self.deterministic`` is set, the RNG is reseeded once per interval so
    the sampled shard sequence is reproducible across runs.
    """
    if isinstance(self.interval, SharedCount):
        # Interval is a shared counter updated externally (e.g. by the
        # training loop) -- just read the current value.
        current_interval = self.interval.get_value()
    else:
        # NOTE: tracking the interval locally like this is problematic in a
        # multi-process setting (dataloader workers or distributed train),
        # since different workers may wrap at different times (or not at all).
        self.interval += 1
        current_interval = self.interval
    if self.deterministic:
        # Reseed with a value derived from the interval so sampling is
        # reproducible for a given (seed, interval) pair.
        if self.worker_seed_fn is None:
            # The pytorch worker seed is deterministic per worker: it is
            # initialized from the process seed, rank, and worker id.
            interval_seed = get_pytorch_worker_seed(current_interval, initial_seed=self.seed)
        else:
            interval_seed = self.worker_seed_fn() + current_interval
        self.rng.seed(interval_seed)
    for _ in range(self.nshards):
        if self.weights is not None:
            yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
        else:
            yield dict(url=self.rng.choice(self.urls))