in blog/pytorch_cnn_cifar10/source/cifar10.py [0:0]
def _train(args):
    is_distributed = len(args.hosts) > 1 and args.dist_backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.dist_backend, rank=host_rank, world_size=world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
                args.dist_backend, dist.get_world_size()) +
            'Current host rank is {}. Using cuda: {}. Number of gpus: {}'.format(
                dist.get_rank(), torch.cuda.is_available(), args.num_gpus))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info("Device Type: {}".format(device))

    logger.info("Loading Cifar10 dataset")
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True,
                                            download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False,
                                           download=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers)

    logger.info("Model loaded")
    model = Net()

    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(0, args.epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print('Finished Training')
    return _save_model(model, args.model_dir)
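
Scripts like this are typically launched by the SageMaker PyTorch container through an argparse entry point whose defaults come from the platform's SM_* environment variables. The sketch below shows how `_train` might be wired up; the exact hyperparameter names and default values are assumptions and are not taken from this file.

# Hypothetical __main__ block: a minimal sketch of a SageMaker-style entry point.
# Argument names and defaults are assumptions for illustration only.
if __name__ == '__main__':
    import argparse
    import json
    import os

    parser = argparse.ArgumentParser()

    # Training hyperparameters (illustrative defaults).
    parser.add_argument('--workers', type=int, default=2)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dist_backend', type=str, default='gloo')

    # SageMaker container environment; the SM_* variables are set by the platform.
    # String defaults are passed through `type`, so SM_HOSTS (a JSON list) is parsed.
    parser.add_argument('--hosts', type=json.loads,
                        default=os.environ.get('SM_HOSTS', '["localhost"]'))
    parser.add_argument('--current-host', type=str,
                        default=os.environ.get('SM_CURRENT_HOST', 'localhost'))
    parser.add_argument('--model-dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR', '/opt/ml/model'))
    parser.add_argument('--data-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAINING', '/opt/ml/input/data/training'))
    parser.add_argument('--num-gpus', type=int,
                        default=int(os.environ.get('SM_NUM_GPUS', 0)))

    _train(parser.parse_args())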