in community-content/pytorch_efficient_training/resnet_fsdp_wds.py [0:0]
def create_wds_dataloader(rank, args, mode):
    """Create a webdataset pipeline and dataloader for one process.

    Args:
        rank: Rank of this process; used to shard the webdataset shards
            across processes via `wds_split`.
        args: Parsed arguments. Reads `train_data_path`/`eval_data_path`,
            `train_data_size`/`eval_data_size`,
            `train_batch_size`/`eval_batch_size`, `gpus`, and
            `dataloader_num_workers`.
        mode: 'train' selects the training transform/paths; any other
            value selects the eval configuration.

    Returns:
        A `wds.WebLoader` that yields pre-batched (image, label) tuples,
        repeated/truncated to exactly `batches` batches per epoch.
    """
    # ImageNet normalization constants, shared by both transform branches.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if mode == 'train':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        data_path = args.train_data_path
        data_size = args.train_data_size
        batch_size_local = args.train_batch_size
        batch_size_global = args.train_batch_size * args.gpus
        # Since webdataset disallows partial batch, we pad the last batch for
        # train. Integer ceil-division avoids float-precision issues that
        # int(math.ceil(a / b)) can hit for very large dataset sizes.
        batches = -(-data_size // batch_size_global)
    else:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        data_path = args.eval_data_path
        data_size = args.eval_data_size
        batch_size_local = args.eval_batch_size
        batch_size_global = args.eval_batch_size * args.gpus
        # Since webdataset disallows partial batch, we drop the last batch for
        # eval. Floor division is exact, unlike int(a / b) float truncation.
        batches = data_size // batch_size_global
    dataset = wds.DataPipeline(
        wds.SimpleShardList(data_path),
        # Shard the tar files across processes so each rank sees a
        # disjoint subset.
        functools.partial(wds_split, rank=rank, world_size=args.gpus),
        wds.tarfile_to_samples(),
        wds.decode('pil'),
        wds.to_tuple('jpg;png;jpeg cls'),
        wds.map_tuple(transform, identity),
        # Batching happens in the pipeline (not the loader); partial
        # batches are always dropped here.
        wds.batched(batch_size_local, partial=False),
    )
    num_workers = args.dataloader_num_workers
    # batch_size=None: samples arrive pre-batched from wds.batched above.
    dataloader = wds.WebLoader(
        dataset=dataset,
        batch_size=None,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        pin_memory=True).repeat(nbatches=batches)
    print(f'{mode} dataloader | samples: {data_size}, '
          f'num_workers: {num_workers}, '
          f'local batch size: {batch_size_local}, '
          f'global batch size: {batch_size_global}, '
          f'batches: {batches}')
    return dataloader