# worker() — excerpt from community-content/pytorch_efficient_training/resnet_fsdp.py

def worker(gpu, args):
  """Run FSDP training and evaluation in one per-GPU worker process.

  Intended to be spawned once per GPU (e.g. via torch.multiprocessing.spawn).
  Sets up the NCCL process group, builds train/eval dataloaders with
  DistributedSamplers, wraps a ResNet-50 in FSDP, and runs the
  train/evaluate loop for ``args.epochs`` epochs.

  Args:
    gpu: Rank of this process; also used as the local CUDA device index.
    args: Namespace providing ``gpus``, ``train_data_path``,
      ``eval_data_path``, ``train_batch_size``, ``eval_batch_size``,
      ``dataloader_num_workers``, ``device``, and ``epochs``.
  """
  # Bind this process to its GPU *before* initializing the NCCL process
  # group. If set_device runs after init, every rank may allocate NCCL
  # resources on GPU 0 first, which can cause contention or hangs.
  torch.cuda.set_device(gpu)

  # Init process group. init_method='env://' reads MASTER_ADDR/MASTER_PORT
  # from the environment (must be set by the launcher).
  print(f'Initiating process {gpu}')
  dist.init_process_group(
      backend='nccl',
      init_method='env://',
      world_size=args.gpus,
      rank=gpu)

  # Create train dataloader. Shuffling is delegated to the
  # DistributedSampler (hence shuffle=False on the DataLoader itself).
  train_dataset = ImageFolder(
      image_list_file=args.train_data_path,
      transform=torchvision.transforms.Compose([
          torchvision.transforms.RandomResizedCrop(224),
          torchvision.transforms.RandomHorizontalFlip(),
          torchvision.transforms.ToTensor(),
          # Standard ImageNet normalization statistics.
          torchvision.transforms.Normalize(
              mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
      ]))
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset, num_replicas=args.gpus, rank=gpu)
  train_dataloader = torch.utils.data.DataLoader(
      dataset=train_dataset,
      batch_size=args.train_batch_size,
      shuffle=False,
      num_workers=args.dataloader_num_workers,
      pin_memory=True,
      sampler=train_sampler)
  if gpu == 0:
    print(f'Train dataloader | samples: {len(train_dataloader.dataset)}, '
          f'num workers: {train_dataloader.num_workers}, '
          f'global batch size: {args.train_batch_size * args.gpus}, '
          f'batches/epoch: {len(train_dataloader)}')

  # Create eval dataloader. Deterministic resize/center-crop pipeline.
  # NOTE(review): DistributedSampler pads the dataset to divide evenly
  # across ranks, and drop_last=True discards the final partial batch, so
  # a few samples may be duplicated or skipped in the reported accuracy.
  eval_dataset = ImageFolder(
      image_list_file=args.eval_data_path,
      transform=torchvision.transforms.Compose([
          torchvision.transforms.Resize(256),
          torchvision.transforms.CenterCrop(224),
          torchvision.transforms.ToTensor(),
          torchvision.transforms.Normalize(
              mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
      ]))
  eval_sampler = torch.utils.data.distributed.DistributedSampler(
      eval_dataset, num_replicas=args.gpus, rank=gpu)
  eval_dataloader = torch.utils.data.DataLoader(
      dataset=eval_dataset,
      batch_size=args.eval_batch_size,
      shuffle=False,
      num_workers=args.dataloader_num_workers,
      pin_memory=True,
      drop_last=True,
      sampler=eval_sampler)
  if gpu == 0:
    print(f'Eval dataloader | samples: {len(eval_dataloader.dataset)}, '
          f'num workers: {eval_dataloader.num_workers}, '
          f'batch size: {args.eval_batch_size}, '
          f'batches/epoch: {len(eval_dataloader)}')

  # Auto-wrap policy: shard every submodule with >= 100 parameters into
  # its own FSDP unit.
  my_auto_wrap_policy = functools.partial(
      size_based_auto_wrap_policy, min_num_params=100)

  # Create model and wrap it with FSDP so parameters/gradients/optimizer
  # state are sharded across ranks.
  model = resnet50(weights=None)
  model.to(args.device)
  model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)

  # Optimizer must be constructed *after* FSDP wrapping so it references
  # the flattened, sharded parameters.
  optimizer = torch.optim.SGD(model.parameters(), 0.1)

  # Main loop.
  metric = torchmetrics.classification.Accuracy(top_k=1).to(args.device)
  for epoch in range(1, args.epochs + 1):
    if gpu == 0:
      print(f'Running epoch {epoch}')
    # Re-seed the sampler each epoch so shuffling differs across epochs.
    train_sampler.set_epoch(epoch)

    start = time.time()
    train(model, args.device, train_dataloader, optimizer)
    end = time.time()
    if gpu == 0:
      print(f'Training finished in {(end - start):>0.3f} seconds')

    start = time.time()
    evaluate(model, args.device, eval_dataloader, metric)
    end = time.time()
    if gpu == 0:
      print(f'Evaluation finished in {(end - start):>0.3f} seconds')

  if gpu == 0:
    print('Done')
  dist.destroy_process_group()