def main()

in community-content/pytorch_image_classification_distributed_data_parallel_training_with_vertex_sdk/trainer/task.py


def main():
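  """Entry point: resolves output locations, optionally initializes distributed
  training, trains an MNIST classifier, and writes the model, checkpoints, and
  TensorBoard logs.
  """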

  args = parse_args()

  local_data_dir = './tmp/data'
  local_model_dir = './tmp/model'
  local_tensorboard_log_dir = './tmp/logs'
  local_checkpoint_dir = './tmp/checkpoints'

  model_dir = args.model_dir or local_model_dir
  tensorboard_log_dir = args.tensorboard_log_dir or local_tensorboard_log_dir
  checkpoint_dir = args.checkpoint_dir or local_checkpoint_dir

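  # Vertex AI custom training mounts Cloud Storage buckets via Cloud Storage
  # FUSE, so a gs://bucket/path URI is reachable as the local path
  # /gcs/bucket/path. Rewrite any gs:// locations accordingly.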
  gs_prefix = 'gs://'
  gcsfuse_prefix = '/gcs/'
  if model_dir and model_dir.startswith(gs_prefix):
    model_dir = model_dir.replace(gs_prefix, gcsfuse_prefix)
  if tensorboard_log_dir and tensorboard_log_dir.startswith(gs_prefix):
    tensorboard_log_dir = tensorboard_log_dir.replace(gs_prefix, gcsfuse_prefix)
  if checkpoint_dir and checkpoint_dir.startswith(gs_prefix):
    checkpoint_dir = checkpoint_dir.replace(gs_prefix, gcsfuse_prefix)

  writer = SummaryWriter(tensorboard_log_dir)

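  # Rank 0 acts as the chief. Only the chief creates the checkpoint directory;
  # makedirs() here is presumably a small helper defined elsewhere in task.py
  # that creates the directory if it does not already exist.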
  is_chief = args.rank == 0
  if is_chief:
    makedirs(checkpoint_dir)
    print(f'Checkpoints will be saved to {checkpoint_dir}')

  checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
  print(f'checkpoint_path is {checkpoint_path}')

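  # With more than one process, join the PyTorch distributed process group.
  # init_method (for example an env:// or tcp://host:port URL) tells the
  # processes how to discover each other; the backend is typically 'nccl'
  # for GPU training or 'gloo' for CPU training.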
  if args.world_size > 1:
    print(f'Initializing distributed backend with {args.world_size} nodes')
    distributed.init_process_group(
        backend=args.backend,
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )
    print(f'[{os.getpid()}]: '
          f'world_size = {distributed.get_world_size()}, '
          f'rank = {distributed.get_rank()}, '
          f'backend={distributed.get_backend()} \n', end='')

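  # Bind this process to the GPU whose index matches its rank (this assumes
  # ranks map one-to-one onto local GPUs); otherwise fall back to the CPU.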
  if torch.cuda.is_available() and not args.no_cuda:
    device = torch.device('cuda:{}'.format(args.rank))
  else:
    device = torch.device('cpu')

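  # DistributedDataParallel keeps one model replica per process and averages
  # gradients across processes during the backward pass, keeping all replicas
  # in sync after every optimizer step.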
  model = Net(device=device)
  if distributed_is_initialized():
    model.to(device)
    model = DistributedDataParallel(model)

  if is_chief:
    # All processes should see the same parameters: they all start from the
    # same random parameters, and gradients are synchronized in the backward
    # passes. Therefore, saving the state dict from one process is sufficient.
    torch.save(model.state_dict(), checkpoint_path)
    print(f'Initial chief checkpoint is saved to {checkpoint_path}')

  # Use a barrier() to make sure the other processes load the model only after
  # the chief (process 0) has saved it.
  if distributed_is_initialized():
    distributed.barrier()
    # map_location remaps the saved tensors onto this process's device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f'Initial chief checkpoint is loaded from {checkpoint_path} with map_location {device}')
  else:
    model.load_state_dict(torch.load(checkpoint_path))
    print(f'Initial chief checkpoint is loaded from {checkpoint_path}')

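  # One Adam optimizer per process; under DDP the gradients are already
  # averaged across processes by the time the optimizer step runs.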
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

  train_loader = MNISTDataLoader(
      local_data_dir, args.batch_size, train=True)
  test_loader = MNISTDataLoader(
      local_data_dir, args.batch_size, train=False)

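  # MNISTDataLoader and Trainer are presumably helpers defined alongside this
  # module (not shown here); Trainer is given checkpoint_path and model_name,
  # presumably for checkpointing during training and naming the exported model.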
  trainer = Trainer(
      model=model,
      optimizer=optimizer,
      train_loader=train_loader,
      test_loader=test_loader,
      device=device,
      model_name='mnist.pt',
      checkpoint_path=checkpoint_path,
  )
  trainer.fit(args.epochs, is_chief, writer)

  # Only the local fallback directory needs to be created; a /gcs/ path is
  # already mounted. The trained model is saved regardless of destination.
  if model_dir == local_model_dir:
    makedirs(model_dir)
  trainer.save(model_dir)
  print(f'Model is saved to {model_dir}')

  print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')

  writer.close()

  if is_chief:
    os.remove(checkpoint_path)

  if distributed_is_initialized():
    distributed.destroy_process_group()

  return
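
# Example launch, one process per replica (a sketch only; the exact flag names
# depend on parse_args(), which is not shown here):
#
#   python -m trainer.task \
#       --backend=nccl \
#       --init-method=tcp://${MASTER_ADDR}:${MASTER_PORT} \
#       --world-size=4 \
#       --rank=${RANK} \
#       --model-dir=gs://my-bucket/model \
#       --epochs=5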