in scripts/train_imagenet.py
import time

import torch
import torch.distributed as dist
import torch.nn as nn

import utils  # project-local helpers (AverageMeter, accuracy_sum); see the sketch below


# Relies on module-level globals set up elsewhere in this script: args, logger, conf, tb.
def train(train_loader, model, criterion, optimizer, scheduler, epoch):
    global logger, conf, tb
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    # epoch-based schedules step once per epoch, before the loop
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(epoch)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # step-based schedules step once per iteration
        if conf["optimizer"]["schedule"]["mode"] == "step":
            scheduler.step(i + epoch * len(train_loader))

        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if conf["optimizer"]["clip"] != 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), conf["optimizer"]["clip"])
        optimizer.step()
        # measure accuracy and record loss
        with torch.no_grad():
            output = output.detach()
            # criterion returns the batch mean; rescale to a per-batch sum so
            # the all-reduced total can be divided by the global sample count
            loss = loss.detach() * target.shape[0]
            prec1, prec5 = utils.accuracy_sum(output, target, topk=(1, 5))
            count = target.new_tensor([target.shape[0]], dtype=torch.long)
            if dist.is_initialized():
                dist.all_reduce(count, dist.ReduceOp.SUM)
            for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
                if dist.is_initialized():
                    dist.all_reduce(val, dist.ReduceOp.SUM)
                val /= count.item()
                meter.update(val.item(), count.item())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logger.info(
                "Epoch: [{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
                    epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    top1=top1,
                    top5=top5,
                )
            )
            # only rank 0 writes TensorBoard scalars
            if not dist.is_initialized() or dist.get_rank() == 0:
                step = i + epoch * len(train_loader)
                tb.add_scalar("train/loss", losses.val, step)
                tb.add_scalar("train/lr", scheduler.get_lr()[0], step)
                tb.add_scalar("train/top1", top1.val, step)
                tb.add_scalar("train/top5", top5.val, step)
        if args.log_hist and i % 10 == 0:
            # periodically log histograms of the classifier and output-BN parameters
            for name, param in model.named_parameters():
                if "fc" in name or "bn_out" in name:
                    tb.add_histogram(
                        name,
                        param.clone().cpu().data.numpy(),
                        i + epoch * len(train_loader),
                    )
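
The utils module is not part of this excerpt. Below is a minimal sketch of the two helpers the loop relies on, reconstructed from how they are used above: accuracy_sum is assumed to return summed (not averaged) top-k precision tensors, so the all-reduced totals can be divided by the global sample count, and AverageMeter is assumed to expose val and avg for the log format strings. This is a hypothetical reconstruction, not the repository's actual implementation.

# NOTE: hypothetical sketch of utils.AverageMeter / utils.accuracy_sum,
# inferred from their call sites in train(); the real module may differ.
class AverageMeter:
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy_sum(output, target, topk=(1,)):
    """Summed top-k precision (percent * number of correct samples) per k.

    Returns tensors on the same device as output, so they can be passed
    directly to dist.all_reduce before dividing by the global sample count.
    """
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # shape [maxk, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum() * 100.0 for k in topk]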