def main()

in community-content/pytorch_image_classification_single_gpu_with_vertex_sdk_and_torchserve/trainer/task.py [0:0]


def _to_gcsfuse_path(path, gs_prefix='gs://', gcsfuse_prefix='/gcs/'):
  """Rewrite a leading ``gs://`` scheme in `path` to the ``/gcs/`` prefix.

  Only the leading prefix is rewritten; any later occurrence of `gs_prefix`
  inside the path is left untouched. Paths that do not start with `gs_prefix`
  (including empty or None values) are returned unchanged.

  Args:
    path: A local path or a ``gs://bucket/...`` URI.
    gs_prefix: URI scheme prefix to look for at the start of `path`.
    gcsfuse_prefix: Prefix substituted for `gs_prefix`.
      NOTE(review): presumably the Cloud Storage FUSE mount root used by
      Vertex AI custom training -- confirm against the job environment.

  Returns:
    The path with its leading `gs_prefix` swapped for `gcsfuse_prefix`, or
    `path` itself when the prefix is absent.
  """
  if path and path.startswith(gs_prefix):
    return gcsfuse_prefix + path[len(gs_prefix):]
  return path


def main():
  """Train the image classifier and save its weights and TensorBoard logs.

  Reads the command-line arguments via `parse_args()`, resolves the model and
  TensorBoard output directories (falling back to local `./tmp/...` folders
  when the corresponding argument is not set), builds the train/val data
  loaders, runs `train()` with SGD and a step learning-rate schedule, and
  writes the resulting `state_dict` to `<model_dir>/antandbee.pth`.

  The TensorBoard writer is closed even if a later step raises.
  """
  args = parse_args()

  local_data_dir = './tmp/data'
  local_model_dir = './tmp/model'
  local_tensorboard_log_dir = './tmp/logs'

  # Output locations given as gs:// URIs are accessed through the /gcs/ path.
  model_dir = _to_gcsfuse_path(args.model_dir or local_model_dir)
  tensorboard_log_dir = _to_gcsfuse_path(
      args.tensorboard_log_dir or local_tensorboard_log_dir)

  makedirs(model_dir)
  writer = SummaryWriter(tensorboard_log_dir)

  try:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Device: {device}')

    data_dir = download_data(local_data_dir)
    image_datasets, class_names = load_dataset(data_dir)

    phases = ['train', 'val']

    dataset_sizes = {x: len(image_datasets[x]) for x in phases}
    print(f'Dataset sizes: {dataset_sizes}')

    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
        for x in phases
    }

    model_ft = load_model(class_names, device)

    criterion = torch.nn.CrossEntropyLoss()

    # Every parameter of the model is handed to the optimizer.
    optimizer_ft = optim.SGD(
        model_ft.parameters(), lr=args.learning_rate, momentum=args.momentum)

    # StepLR multiplies the LR by gamma=0.1 every 7 scheduler steps
    # (presumably one step per epoch inside train() -- confirm there).
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer_ft, step_size=7, gamma=0.1)

    model = train(
        model=model_ft,
        criterion=criterion,
        optimizer=optimizer_ft,
        scheduler=exp_lr_scheduler,
        dataset_sizes=dataset_sizes,
        dataloaders=dataloaders,
        device=device,
        epochs=args.epochs,
        writer=writer,
    )

    model_name = 'antandbee.pth'
    model_path = os.path.join(model_dir, model_name)

    # Only the weights (state_dict) are persisted, not the full module object.
    torch.save(model.state_dict(), model_path)
    print(f'Model is saved to {model_dir}')

    print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')
  finally:
    # Close the writer on both success and failure paths.
    writer.close()