in bring-your-own-container/pytorch_extending_our_containers/container/cifar10/cifar10.py [0:0]
def _train(args):
    """Train the CIFAR-10 ``Net`` model and save it to ``args.model_dir``.

    Optionally initializes ``torch.distributed`` when more than one host is
    listed, loads the CIFAR-10 training split from ``args.data_dir``, runs SGD
    for ``args.epochs`` epochs and hands the trained model to ``_save_model``.

    Args:
        args: Parsed training arguments. This function reads ``hosts``,
            ``current_host``, ``dist_backend``, ``num_gpus``, ``data_dir``,
            ``batch_size``, ``workers``, ``lr``, ``momentum``, ``epochs`` and
            ``model_dir``.

    Returns:
        The result of ``_save_model(model, args.model_dir)``.
    """
    is_distributed = len(args.hosts) > 1 and args.dist_backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))

    if is_distributed:
        # Initialize the distributed environment: one rank per host, where the
        # rank is the position of the current host in args.hosts.
        world_size = len(args.hosts)
        os.environ["WORLD_SIZE"] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        dist.init_process_group(backend=args.dist_backend, rank=host_rank, world_size=world_size)
        logger.info(
            "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
                args.dist_backend, dist.get_world_size()
            )
            + "Current host rank is {}. Using cuda: {}. Number of gpus: {}".format(
                dist.get_rank(), torch.cuda.is_available(), args.num_gpus
            )
        )
        # NOTE(review): the model below is wrapped in nn.DataParallel at most,
        # never in DistributedDataParallel, and the loader uses no distributed
        # sampler, so the process group initialized here does not appear to be
        # used for gradient synchronization -- confirm whether that is intended.

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Device Type: {}".format(device))

    logger.info("Loading Cifar10 dataset")
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    # Only the training split is loaded: this function performs no evaluation.
    # download=False means the data must already be present under args.data_dir.
    trainset = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=True, download=False, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers
    )

    model = Net()
    logger.info("Model loaded")

    # Replicate across local GPUs when more than one is visible.
    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Number of mini-batches between running-loss reports.
    log_interval = 2000

    for epoch in range(args.epochs):
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(train_loader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics: mean loss over the last log_interval batches
            running_loss += loss.item()
            if step % log_interval == 0:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, step, running_loss / log_interval))
                running_loss = 0.0

    print("Finished Training")
    return _save_model(model, args.model_dir)