in torchbenchmark/util/framework/timm/train.py [0:0]
# Imports this function relies on (resolved at module level in the full file);
# AverageMeter, dispatch_clip_grad, model_parameters and reduce_tensor come from timm.
from contextlib import suppress

import torch
from timm.utils import AverageMeter, dispatch_clip_grad, model_parameters, reduce_tensor
def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
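    """Run one training epoch of a timm model for benchmarking.

    Adapted from timm's training loop: the timing meters, console logging,
    image dumps and recovery checkpointing are commented out, and the loop
    is capped at args.train_num_batch batches.
    """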
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

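    # Second-order optimizers (e.g. AdaHessian) need create_graph=True in
    # backward() so that second derivatives can be computed.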
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    # batch_time_m = AverageMeter()
    # data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    # end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
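    # zip() with a bounded range caps the epoch at args.train_num_batch batches,
    # so a benchmark run touches only a fixed slice of the dataset.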
    for batch_idx, (input, target) in zip(range(args.train_num_batch), loader):
        last_batch = batch_idx == last_idx
        # data_time_m.update(time.time() - end)
        if not args.prefetcher and args.device == "cuda":
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

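        # amp_autocast defaults to contextlib.suppress, a no-op context manager,
        # so the forward pass runs in full precision unless an autocast is passed in.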
        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
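        # When AMP is active, loss_scaler (e.g. timm's NativeScaler) scales the
        # loss, runs backward, optionally clips gradients, and steps the optimizer
        # itself; otherwise the same steps are done manually in the else branch.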
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad, mode=args.clip_mode)
            optimizer.step()

        # if model_ema is not None:
        #     model_ema.update(model)
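        # Synchronize so queued CUDA kernels finish inside this iteration,
        # keeping any wall-clock timing taken around this loop accurate.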
        if args.device == "cuda":
            torch.cuda.synchronize()
        num_updates += 1
        # batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)  # consumed only by the commented-out log line below

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))
            # if args.local_rank == 0:
            #     _logger.info(
            #         'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
            #         'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
            #         'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
            #         '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
            #         'LR: {lr:.3e} '
            #         'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
            #             epoch,
            #             batch_idx, len(loader),
            #             100. * batch_idx / last_idx,
            #             loss=losses_m,
            #             batch_time=batch_time_m,
            #             rate=input.size(0) * args.world_size / batch_time_m.val,
            #             rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
            #             lr=lr,
            #             data_time=data_time_m))
            #
            #     if args.save_images and output_dir:
            #         torchvision.utils.save_image(
            #             input,
            #             os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
            #             padding=0,
            #             normalize=True)

        # if saver is not None and args.recovery_interval and (
        #         last_batch or (batch_idx + 1) % args.recovery_interval == 0):
        #     saver.save_recovery(epoch, batch_idx=batch_idx)
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        # end = time.time()
        # end for

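    # Lookahead-style optimizers keep a separate set of "slow" weights;
    # sync_lookahead() folds them back into the model parameters at epoch end.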
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()