in src/chug/wds/filters.py
def run(self, src):
if isinstance(self.interval, SharedCount):
interval = self.interval.get_value()
else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader
            # workers or distributed training) situation, as different workers may
            # wrap at different times (or not at all).
self.interval += 1
interval = self.interval
rng = random.Random()
if self.unique_worker:
            # Use the PyTorch worker's seed, *different* across all nodes/workers,
            # but still deterministic if worker seeds are set consistently
seed = get_pytorch_worker_seed(interval, initial_seed=self.seed)
else:
            # This seed is deterministic AND the *same* across all nodes/workers in each epoch/interval
seed = self.seed + interval
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
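
For context, below is a minimal sketch of the two helpers this method depends on. Both shapes are assumptions, not chug's actual implementations: `SharedCount` is sketched as a thin wrapper over `multiprocessing.Value`, so the training loop can bump the interval in the parent process and have dataloader workers observe the update, and `get_pytorch_worker_seed` is sketched as deriving a per-worker seed from `torch.utils.data.get_worker_info()`, offset by the interval.

from multiprocessing import Value

import torch.utils.data


class SharedCount:
    # Sketch (assumed): a process-shared integer; set_value() in the parent
    # process is visible to dataloader workers calling get_value(), which
    # avoids the wrap-at-different-times problem noted above.
    def __init__(self, count: int = 0):
        self.count = Value('i', count)

    def set_value(self, count: int):
        self.count.value = count

    def get_value(self) -> int:
        return self.count.value


def get_pytorch_worker_seed(increment: int = 0, initial_seed: int = 0) -> int:
    # Sketch (assumed): inside a dataloader worker, start from the distinct
    # per-worker seed PyTorch already assigned; spacing the increment by
    # num_workers keeps re-seeded values from colliding across workers in
    # different intervals.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        return worker_info.seed + increment * max(1, worker_info.num_workers)
    # Outside a worker process, fall back to a deterministic combination.
    return initial_seed + increment

Either way the shuffle is re-seeded from the interval, so the buffer order changes every epoch while staying reproducible when seeds are configured consistently.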