def run_training()

in community-content/pytorch_efficient_training/resnet.py [0:0]


# Per-channel mean/std used by both the train and eval input pipelines.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _make_dataloader(image_list_file, augmentations, batch_size, shuffle,
                     num_workers):
  """Build a DataLoader over `ImageFolder` with the shared tensor transforms.

  Args:
    image_list_file: value passed to `ImageFolder(image_list_file=...)`.
    augmentations: list of torchvision transforms applied before
      `ToTensor()` and `Normalize(...)`.
    batch_size: samples per batch.
    shuffle: whether the loader reshuffles every epoch.
    num_workers: number of DataLoader worker processes.

  Returns:
    A `torch.utils.data.DataLoader` with `pin_memory=True`. No batch is
    dropped, so every sample in the dataset is visited.
  """
  dataset = ImageFolder(
      image_list_file=image_list_file,
      transform=torchvision.transforms.Compose(augmentations + [
          torchvision.transforms.ToTensor(),
          torchvision.transforms.Normalize(
              mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
      ]))
  return torch.utils.data.DataLoader(
      dataset=dataset,
      batch_size=batch_size,
      shuffle=shuffle,
      num_workers=num_workers,
      pin_memory=True)


def _print_dataloader_summary(label, dataloader, batch_size):
  """Print sample count, worker count, batch size and batches per epoch."""
  print(f'{label} dataloader | samples: {len(dataloader.dataset)}, '
        f'num workers: {dataloader.num_workers}, '
        f'batch size: {batch_size}, '
        f'batches/epoch: {len(dataloader)}')


def run_training(args):
  """Run training and evaluation.

  Builds a ResNet-50 (no pretrained weights), the train/eval dataloaders and
  an SGD optimizer. Then, for each epoch, runs `train(...)` followed by
  `evaluate(...)`, printing how long each phase took.

  Args:
    args: namespace-like object read for `device`, `train_data_path`,
      `eval_data_path`, `train_batch_size`, `eval_batch_size`,
      `dataloader_num_workers` and `epochs`.
  """
  # Create model.
  model = resnet50(weights=None).to(args.device)

  # Create train dataloader (random crop/flip augmentation, shuffled).
  train_dataloader = _make_dataloader(
      args.train_data_path,
      [
          torchvision.transforms.RandomResizedCrop(224),
          torchvision.transforms.RandomHorizontalFlip(),
      ],
      batch_size=args.train_batch_size,
      shuffle=True,
      num_workers=args.dataloader_num_workers)
  _print_dataloader_summary('Train', train_dataloader, args.train_batch_size)

  # Create eval dataloader (deterministic resize + center crop).
  # The final partial batch is kept, unlike the original drop_last=True.
  # Dropping it excluded up to eval_batch_size - 1 samples from the
  # reported accuracy.
  eval_dataloader = _make_dataloader(
      args.eval_data_path,
      [
          torchvision.transforms.Resize(256),
          torchvision.transforms.CenterCrop(224),
      ],
      batch_size=args.eval_batch_size,
      shuffle=False,
      num_workers=args.dataloader_num_workers)
  _print_dataloader_summary('Eval', eval_dataloader, args.eval_batch_size)

  # Optimizer.
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

  # Main loop.
  # NOTE(review): torchmetrics >= 0.11 requires `task=`/`num_classes=` for
  # Accuracy. Confirm this call matches the pinned torchmetrics version.
  metric = torchmetrics.classification.Accuracy(top_k=1).to(args.device)
  for epoch in range(1, args.epochs + 1):
    print(f'Running epoch {epoch}')

    # perf_counter is monotonic, so durations are immune to wall-clock
    # adjustments.
    start = time.perf_counter()
    train(model, args.device, train_dataloader, optimizer)
    print(f'Training finished in {(time.perf_counter() - start):>0.3f} seconds')

    start = time.perf_counter()
    evaluate(model, args.device, eval_dataloader, metric)
    print(f'Evaluation finished in '
          f'{(time.perf_counter() - start):>0.3f} seconds')
  print('Done')