# community-content/pytorch_efficient_training/resnet_ddp.py
def worker(gpu, args):
    """Run DDP training and evaluation in one per-GPU process.

    Spawned once per GPU. Initializes the NCCL process group (rank ==
    ``gpu``, so this assumes a single-node job where ``world_size ==
    args.gpus``), builds a SyncBatchNorm ResNet-50 wrapped in
    DistributedDataParallel, creates sharded train/eval dataloaders, and
    runs ``args.epochs`` epochs of train + evaluate. Rank 0 does all the
    progress printing.

    Args:
        gpu: Local GPU index for this process; also used as the global rank.
        args: Parsed CLI namespace. Reads: gpus, device, train_data_path,
            eval_data_path, train_batch_size, eval_batch_size,
            dataloader_num_workers, epochs.
    """
    # Init process group. The env:// init method requires MASTER_ADDR and
    # MASTER_PORT to be set in the environment before this call.
    print(f'Initiating process {gpu}')
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.gpus,
        rank=gpu)
    # Create model. set_device must run before .to() so that a bare 'cuda'
    # device resolves to this process's GPU (assumes args.device is
    # 'cuda' — TODO confirm against the arg parser).
    model = resnet50(weights=None)
    torch.cuda.set_device(gpu)
    model.to(args.device)
    # Swap BatchNorm for SyncBatchNorm so normalization statistics are
    # reduced across all ranks; must happen before the DDP wrap.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # Create train dataloader. The DistributedSampler shards the dataset
    # across ranks, so shuffle=False here (the sampler shuffles instead).
    train_dataset = ImageFolder(
        image_list_file=args.train_data_path,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.gpus, rank=gpu)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        sampler=train_sampler)
    if gpu == 0:
        print(f'Train dataloader | samples: {len(train_dataloader.dataset)}, '
              f'num workers: {train_dataloader.num_workers}, '
              f'global batch size: {args.train_batch_size * args.gpus}, '
              f'batches/epoch: {len(train_dataloader)}')
    # Create eval dataloader. Also sharded; drop_last=True keeps per-rank
    # batch counts equal, at the cost of skipping a few tail samples.
    eval_dataset = ImageFolder(
        image_list_file=args.eval_data_path,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset, num_replicas=args.gpus, rank=gpu)
    eval_dataloader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=eval_sampler)
    if gpu == 0:
        print(f'Eval dataloader | samples: {len(eval_dataloader.dataset)}, '
              f'num workers: {eval_dataloader.num_workers}, '
              f'batch size: {args.eval_batch_size}, '
              f'batches/epoch: {len(eval_dataloader)}')
    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), 0.1)
    # Main loop. The metric lives on the device so evaluate() can update it
    # without host round-trips.
    metric = torchmetrics.classification.Accuracy(top_k=1).to(args.device)
    for epoch in range(1, args.epochs + 1):
        if gpu == 0:
            print(f'Running epoch {epoch}')
        # Required so the sampler reshuffles with a different seed each
        # epoch; without it every epoch sees the same ordering.
        train_sampler.set_epoch(epoch)
        start = time.time()
        train(model, args.device, train_dataloader, optimizer)
        end = time.time()
        if gpu == 0:
            print(f'Training finished in {(end - start):>0.3f} seconds')
        start = time.time()
        evaluate(model, args.device, eval_dataloader, metric)
        end = time.time()
        if gpu == 0:
            print(f'Evaluation finished in {(end - start):>0.3f} seconds')
    if gpu == 0:
        print('Done')
    # Tear down the NCCL process group on every rank — the original leaked
    # it; destroy_process_group() is the documented cleanup for DDP jobs.
    dist.destroy_process_group()