community-content/pytorch_image_classification_distributed_data_parallel_training_with_vertex_sdk/trainer/task.py
def main():
    args = parse_args()

    local_data_dir = './tmp/data'
    local_model_dir = './tmp/model'
    local_tensorboard_log_dir = './tmp/logs'
    local_checkpoint_dir = './tmp/checkpoints'

    model_dir = args.model_dir or local_model_dir
    tensorboard_log_dir = args.tensorboard_log_dir or local_tensorboard_log_dir
    checkpoint_dir = args.checkpoint_dir or local_checkpoint_dir
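    # Vertex AI custom training mounts Cloud Storage buckets via Cloud Storage
    # FUSE, so a gs://bucket/path URI can be read and written as the local
    # path /gcs/bucket/path.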
    gs_prefix = 'gs://'
    gcsfuse_prefix = '/gcs/'
    if model_dir and model_dir.startswith(gs_prefix):
        model_dir = model_dir.replace(gs_prefix, gcsfuse_prefix)
    if tensorboard_log_dir and tensorboard_log_dir.startswith(gs_prefix):
        tensorboard_log_dir = tensorboard_log_dir.replace(gs_prefix, gcsfuse_prefix)
    if checkpoint_dir and checkpoint_dir.startswith(gs_prefix):
        checkpoint_dir = checkpoint_dir.replace(gs_prefix, gcsfuse_prefix)
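    # Write TensorBoard summaries to the resolved (possibly FUSE-mounted) log
    # directory; only the chief (rank 0) creates the checkpoint directory.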
    writer = SummaryWriter(tensorboard_log_dir)

    is_chief = args.rank == 0
    if is_chief:
        makedirs(checkpoint_dir)
        print(f'Checkpoints will be saved to {checkpoint_dir}')

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
    print(f'checkpoint_path is {checkpoint_path}')
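    # With more than one replica, join the process group so that
    # DistributedDataParallel can synchronize gradients across processes.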
    if args.world_size > 1:
        print('Initializing distributed backend with {} nodes'.format(args.world_size))
        distributed.init_process_group(
            backend=args.backend,
            init_method=args.init_method,
            world_size=args.world_size,
            rank=args.rank,
        )
        print(f'[{os.getpid()}]: '
              f'world_size = {distributed.get_world_size()}, '
              f'rank = {distributed.get_rank()}, '
              f'backend={distributed.get_backend()} \n', end='')
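    # Use the global rank as the CUDA device index; this assumes each rank
    # maps onto a distinct GPU that is visible on its host.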
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device('cuda:{}'.format(args.rank))
    else:
        device = torch.device('cpu')

    model = Net(device=device)
    if distributed_is_initialized():
        model.to(device)
        model = DistributedDataParallel(model)
    if is_chief:
        # All processes start from the same random parameters, and gradients are
        # synchronized in the backward passes, so every process sees the same
        # model. Saving it from a single process is therefore sufficient.
        torch.save(model.state_dict(), checkpoint_path)
        print(f'Initial chief checkpoint is saved to {checkpoint_path}')

    # Use a barrier() to make sure the other processes load the checkpoint only
    # after the chief (rank 0) has finished saving it.
    if distributed_is_initialized():
        distributed.barrier()
        # map_location remaps the saved tensors onto this process's device.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f'Initial chief checkpoint is loaded from {checkpoint_path} with map_location {device}')
    else:
        model.load_state_dict(torch.load(checkpoint_path))
        print(f'Initial chief checkpoint is loaded from {checkpoint_path}')
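    # MNISTDataLoader and Trainer are helpers defined elsewhere in this trainer
    # package; the loaders presumably download MNIST into the local data directory.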
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    train_loader = MNISTDataLoader(
        local_data_dir, args.batch_size, train=True)
    test_loader = MNISTDataLoader(
        local_data_dir, args.batch_size, train=False)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
        model_name='mnist.pt',
        checkpoint_path=checkpoint_path,
    )
    trainer.fit(args.epochs, is_chief, writer)
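    # Persist the trained model, then clean up: the chief removes the temporary
    # checkpoint and every process leaves the distributed process group.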
    if model_dir == local_model_dir:
        makedirs(model_dir)
    trainer.save(model_dir)
    print(f'Model is saved to {model_dir}')
    print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')

    writer.close()

    if is_chief:
        os.remove(checkpoint_path)

    if distributed_is_initialized():
        distributed.destroy_process_group()

    return