# create_wds_dataloader()
#
# from community-content/pytorch_efficient_training/resnet_ddp_wds.py


def create_wds_dataloader(rank, args, mode):
  """Create a webdataset pipeline and dataloader for one DDP rank.

  Args:
    rank: Global rank of this process; used to shard the tar shards so that
      each rank reads a disjoint subset of the data.
    args: Parsed command-line arguments. Reads the train/eval data path and
      size, train/eval batch size, `gpus` (world size), and
      `dataloader_num_workers`.
    mode: 'train' or anything else (treated as eval). Selects augmentation,
      data source, and last-batch policy: pad for train, drop for eval.

  Returns:
    A wds.WebLoader that yields `batches` batches per epoch on this rank.
  """
  # ImageNet channel statistics, shared by the train and eval pipelines.
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  if mode == 'train':
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    data_path = args.train_data_path
    data_size = args.train_data_size
    batch_size_local = args.train_batch_size
    batch_size_global = args.train_batch_size * args.gpus
    # Since webdataset disallows partial batch, we pad the last batch for
    # train. Exact integer ceil (avoids float imprecision of
    # math.ceil(a / b) for very large dataset sizes).
    batches = -(-data_size // batch_size_global)
  else:
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    data_path = args.eval_data_path
    data_size = args.eval_data_size
    batch_size_local = args.eval_batch_size
    batch_size_global = args.eval_batch_size * args.gpus
    # Since webdataset disallows partial batch, we drop the last batch for
    # eval. Exact integer floor (avoids float imprecision of
    # int(a / b) for very large dataset sizes).
    batches = data_size // batch_size_global

  dataset = wds.DataPipeline(
      wds.SimpleShardList(data_path),
      # Shard the tar files across ranks so each rank reads distinct data.
      functools.partial(wds_split, rank=rank, world_size=args.gpus),
      wds.tarfile_to_samples(),
      wds.decode('pil'),
      wds.to_tuple('jpg;png;jpeg cls'),
      wds.map_tuple(transform, identity),
      # Batching happens in the pipeline; partial batches are dropped here
      # and the .repeat() below restores the intended batch count.
      wds.batched(batch_size_local, partial=False),
  )
  num_workers = args.dataloader_num_workers
  dataloader = wds.WebLoader(
      dataset=dataset,
      batch_size=None,  # dataset already yields batches
      shuffle=False,
      num_workers=num_workers,
      # persistent_workers requires at least one worker process.
      persistent_workers=num_workers > 0,
      pin_memory=True).repeat(nbatches=batches)
  print(f'{mode} dataloader | samples: {data_size}, '
        f'num_workers: {num_workers}, '
        f'local batch size: {batch_size_local}, '
        f'global batch size: {batch_size_global}, '
        f'batches: {batches}')
  return dataloader