data/data_utils.py

import torch
import torch.distributed as dist


def synchronized_dataloader_step(train_loader, is_dist):
    """
    Create a synchronized iterator that handles uneven data distribution in DDP.

    All ranks stop as soon as the first rank runs out of data. This can happen
    because, when packing a pre-sharded dataset, one rank may end up with fewer
    groups than the others.
    """
    if not is_dist:
        # Single-GPU training needs no synchronization.
        for batch in train_loader:
            yield batch
        return

    # For DDP, synchronize across ranks on every step.
    train_iter = iter(train_loader)
    while True:
        try:
            batch = next(train_iter)
            has_data = torch.tensor(1, device=torch.cuda.current_device())
        except StopIteration:
            batch = None
            has_data = torch.tensor(0, device=torch.cuda.current_device())

        # Synchronize across all ranks: if any rank is out of data, the MIN
        # reduction yields 0 everywhere, so every rank stops together.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)
        if has_data.item() == 0:
            # At least one rank is out of data. All ranks stop.
            break

        yield batch
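
A minimal usage sketch follows, showing how the generator might be consumed in a DDP training loop. The `train_one_epoch` function, `model`, `optimizer`, and the loss computation are hypothetical placeholders for illustration, not part of data_utils.py; only `synchronized_dataloader_step` comes from this module.

import torch.distributed as dist

def train_one_epoch(model, optimizer, train_loader):
    # Hypothetical training loop; model and optimizer are placeholders.
    is_dist = dist.is_available() and dist.is_initialized()
    for batch in synchronized_dataloader_step(train_loader, is_dist):
        optimizer.zero_grad()
        loss = model(batch)  # assumes the model returns a scalar loss
        loss.backward()
        optimizer.step()

Because every rank performs the same number of `all_reduce` calls before breaking, no rank blocks waiting on a collective that its peers never enter, which is the hang this helper is designed to avoid.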