in distributed_training/train_code/pytorch_mnist_smdp.py [0:0]
import torch.nn.functional as F  # needed for F.nll_loss


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)  # negative log-likelihood on log-softmax output
        loss.backward()
        optimizer.step()
        # Log only from rank 0 so output is not duplicated across workers.
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    # Each rank sees 1/world_size of the batches, so scale the
                    # processed-sample count to report global progress.
                    batch_idx * len(data) * args.world_size,
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
        if args.verbose:
            print("Batch", batch_idx, "from rank", args.rank)