in opacus/data_loader.py [0:0]
def switch_generator(*, data_loader: DataLoader, generator) -> DataLoader:
    """
    Creates new instance of a ``DataLoader``, with the exact same behaviour of the
    provided data loader, except for the source of randomness.

    Typically used to enhance a user-provided data loader object with cryptographically
    secure random number generator.

    Note:
        The batch sampler of ``data_loader`` is reused, not copied. Its ``generator``
        (or its inner sampler's ``generator``, for a plain ``BatchSampler``) is
        overwritten in place. After this call the original ``data_loader`` therefore
        also draws its randomness from ``generator``.

    Args:
        data_loader: Any ``DataLoader`` object
        generator: Random number generator object

    Returns:
        New ``DataLoader`` object with the exact same behaviour as the input data loader,
        except for the source of randomness.

    Raises:
        ValueError: If ``data_loader`` has no batch sampler, or its batch sampler is
            rejected by ``_is_supported_batch_sampler``, or it is a ``BatchSampler``
            whose wrapped sampler has no ``generator`` attribute.
    """
    batch_sampler = data_loader.batch_sampler
    if batch_sampler is None or not _is_supported_batch_sampler(batch_sampler):
        raise ValueError(
            "Non-batch processing is not supported: Opacus always assumes one of the input dimensions to be batch dimension."
        )

    if isinstance(batch_sampler, BatchSampler):
        # A plain BatchSampler only groups indices; the randomness (if any) lives in
        # the sampler it wraps, so that is where the generator must be swapped.
        if not hasattr(batch_sampler.sampler, "generator"):
            raise ValueError(
                "Target sampler doesn't have generator attribute: nothing to switch"
            )
        batch_sampler.sampler.generator = generator
    else:
        # NOTE(review): assumes every non-BatchSampler accepted by
        # _is_supported_batch_sampler reads its randomness from a ``generator``
        # attribute — confirm against that helper.
        batch_sampler.generator = generator

    # ``drop_last`` is deliberately NOT forwarded. DataLoader raises ValueError when
    # ``drop_last=True`` is combined with ``batch_sampler`` (they are mutually
    # exclusive). The reused batch sampler already encodes the original loader's
    # drop-last behaviour.
    return DataLoader(
        dataset=data_loader.dataset,
        batch_sampler=batch_sampler,
        num_workers=data_loader.num_workers,
        collate_fn=data_loader.collate_fn,
        pin_memory=data_loader.pin_memory,
        timeout=data_loader.timeout,
        worker_init_fn=data_loader.worker_init_fn,
        multiprocessing_context=data_loader.multiprocessing_context,
        generator=generator,
        prefetch_factor=data_loader.prefetch_factor,
        persistent_workers=data_loader.persistent_workers,
    )