in pytorch_managed_spot_training_checkpointing/source_dir/cifar10.py [0:0]
def _train(args):
    is_distributed = len(args.hosts) > 1 and args.dist_backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))

    if os.path.isdir(args.checkpoint_path):
        print("Checkpointing directory {} exists".format(args.checkpoint_path))
    else:
        print("Creating checkpointing directory {}".format(args.checkpoint_path))
        os.mkdir(args.checkpoint_path)
    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.dist_backend, rank=host_rank, world_size=world_size)
        print(
            "Initialized the distributed environment: '{}' backend on {} nodes. "
            "Current host rank is {}. Using cuda: {}. Number of gpus: {}".format(
                args.dist_backend, dist.get_world_size(), dist.get_rank(),
                torch.cuda.is_available(), args.num_gpus))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device Type: {}".format(device))

    print("Loading Cifar10 dataset")
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True,
                                            download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False,
                                           download=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers)
print("Model loaded")
model = Net()
if torch.cuda.device_count() > 1:
print("Gpu count: {}".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    # Check whether a checkpoint exists; if so, resume from it, otherwise start at epoch 0.
    if not os.path.isfile(os.path.join(args.checkpoint_path, 'checkpoint.pth')):
        epoch_number = 0
    else:
        model, optimizer, epoch_number = _load_checkpoint(model, optimizer, args)
    for epoch in range(epoch_number, args.epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

        # Save a checkpoint at the end of every epoch so training can resume
        # after a spot interruption.
        _save_checkpoint(model, optimizer, epoch, loss, args)
    print('Finished Training')
    return _save_model(model, args.model_dir)
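

# The checkpoint helpers referenced above (_load_checkpoint, _save_checkpoint,
# _save_model) are defined elsewhere in this file and are not shown in this
# excerpt. Below is a minimal sketch of what they might look like, assuming a
# single-dict checkpoint written to <checkpoint_path>/checkpoint.pth and the
# file-level imports of os and torch; the dict keys and the 'model.pth' file
# name are illustrative assumptions, not the confirmed implementation.
def _save_checkpoint(model, optimizer, epoch, loss, args):
    # Persist everything needed to resume training after a spot interruption.
    path = os.path.join(args.checkpoint_path, 'checkpoint.pth')
    torch.save({
        'epoch': epoch + 1,  # epoch to resume from
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, path)
    print("Saved checkpoint for epoch {} to {}".format(epoch + 1, path))


def _load_checkpoint(model, optimizer, args):
    # Restore model/optimizer state and return the epoch to resume from.
    path = os.path.join(args.checkpoint_path, 'checkpoint.pth')
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_number = checkpoint['epoch']
    print("Resuming training from epoch {}".format(epoch_number))
    return model, optimizer, epoch_number


def _save_model(model, model_dir):
    # Save only the model weights for later hosting/inference.
    path = os.path.join(model_dir, 'model.pth')
    torch.save(model.cpu().state_dict(), path)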