in fastmri/pl_modules/data_module.py [0:0]
def _seed_mask_func(dataset, seed):
    """Reseed ``dataset.transform.mask_func.rng`` with ``seed``, if present.

    Datasets whose ``transform`` is missing/``None``, or whose transform has no
    ``mask_func`` attribute (or has ``mask_func=None``), are skipped instead of
    raising ``AttributeError``.
    """
    transform = getattr(dataset, "transform", None)
    mask_func = getattr(transform, "mask_func", None)
    if mask_func is not None:
        # Reduce the (possibly large) seed into a range accepted by the RNG's
        # ``seed`` method; presumably ``rng`` is a NumPy RandomState, which
        # requires a seed in [0, 2**32 - 1].
        mask_func.rng.seed(seed % (2 ** 32 - 1))


def worker_init_fn(worker_id):
    """Handle random seeding for all mask_func.

    Intended for use as a ``torch.utils.data.DataLoader`` ``worker_init_fn``.
    It reseeds the ``rng`` of each ``mask_func`` reachable from the worker's
    dataset. Each (distributed rank, worker, sub-dataset) combination gets a
    distinct seed.

    Args:
        worker_id: Index of the worker being initialised. Not used directly;
            the worker id is read from ``torch.utils.data.get_worker_info()``.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Called outside a DataLoader worker process: there is no per-worker
        # dataset copy to seed.
        return

    data: Union[
        SliceDataset, CombinedSliceDataset
    ] = worker_info.dataset  # pylint: disable=no-member

    # Under DDP every rank spawns its own workers, so fold the rank into the
    # seed. Without DDP, rank 0 yields exactly the non-distributed seeds.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0

    base_seed = worker_info.seed  # pylint: disable=no-member
    worker_idx = worker_info.id  # pylint: disable=no-member
    num_workers = worker_info.num_workers  # pylint: disable=no-member

    if isinstance(data, CombinedSliceDataset):
        num_datasets = len(data.datasets)
        for i, dataset in enumerate(data.datasets):
            # Subtracting the worker id removes its contribution to
            # ``worker_info.seed`` (presumably base + id in PyTorch). The
            # worker/rank offsets are then re-added, scaled by the number of
            # datasets, so every (rank, worker, dataset) seed is distinct.
            seed_i = (
                base_seed
                - worker_idx
                + rank * (num_workers * num_datasets)
                + worker_idx * num_datasets
                + i
            )
            _seed_mask_func(dataset, seed_i)
    else:
        # Single dataset: unique seed per (rank, worker).
        _seed_mask_func(data, base_seed + rank * num_workers)