in data/data_utils.py [0:0]
import torch
import torch.distributed as dist

def synchronized_dataloader_step(train_loader, is_dist):
    """
    Create a synchronized iterator that handles uneven data distribution in DDP.
    All ranks stop as soon as the first rank runs out of data.
    This can happen because, when packing a pre-sharded dataset, a rank may end up with fewer groups than the others.
    """
    if not is_dist:
        # For single GPU, we don't need synchronization.
        for batch in train_loader:
            yield batch
        return
    # For DDP, we need synchronization.
    train_iter = iter(train_loader)
    while True:
        try:
            batch = next(train_iter)
            has_data = torch.tensor(1, device=torch.cuda.current_device())
        except StopIteration:
            batch = None
            has_data = torch.tensor(0, device=torch.cuda.current_device())
        # We synchronize across all ranks. If any rank is out of data, all ranks stop.
        dist.all_reduce(has_data, op=dist.ReduceOp.MIN)
        if has_data.item() == 0:
            # At least one rank is out of data. All ranks should stop.
            break
        yield batch
    return None
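
A minimal usage sketch, assuming a standard DDP training loop; `train_loader`, `model`, and `optimizer` are hypothetical placeholders, not names from this repository:

    # Hypothetical training loop using the synchronized iterator.
    import torch.distributed as dist

    is_dist = dist.is_available() and dist.is_initialized()
    for batch in synchronized_dataloader_step(train_loader, is_dist):
        # Either every rank receives a batch here, or no rank does: the
        # generator stops on all ranks together once any rank runs dry.
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Because the `all_reduce` with `ReduceOp.MIN` runs every step, no rank can enter the collective operations of the next training step (gradient all-reduce, etc.) while another rank has already exhausted its shard, which is what would otherwise cause a hang.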