# torchbenchmark/models/timm_efficientdet/train.py
from collections import OrderedDict
from contextlib import suppress

import torch
from timm.utils import AverageMeter, reduce_tensor


def train_epoch(
        epoch, model, loader, optimizer, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, num_batch=1):
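    """Run up to ``num_batch`` training batches of one epoch.

    A benchmark-trimmed version of the upstream timm/efficientdet training
    loop: timing, logging, and checkpointing are stripped out. Returns an
    OrderedDict containing the average training loss.
    """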
    # Timing meters from the upstream script are omitted in this benchmark.
    losses_m = AverageMeter()

    model.train()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
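    # zip() with range(num_batch) caps the epoch at num_batch batches for
    # benchmarking; with a truncated epoch, last_batch may never become True.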
for batch_idx, (input, target) in zip(range(num_batch), loader):
last_batch = batch_idx == last_idx
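        # Optionally convert the input to channels-last (NHWC) memory format,
        # which can improve convolution throughput on recent GPUs.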
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
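        # amp_autocast defaults to contextlib.suppress(), a no-op context
        # manager; callers opt into mixed precision by passing an autocast
        # context such as torch.cuda.amp.autocast.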
with amp_autocast():
output = model(input, target)
loss = output['loss']
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
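        # timm's loss scalers fold backward(), gradient unscaling, optional
        # clipping, and the optimizer step into a single call.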
if loss_scaler is not None:
loss_scaler(loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters())
else:
loss.backward()
if args.clip_grad:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step()
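        # Block until queued GPU work for this step has finished.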
torch.cuda.synchronize()
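        # Update the exponential moving average copy of the weights, if any.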
if model_ema is not None:
model_ema.update(model)
num_updates += 1
        if last_batch or batch_idx % args.log_interval == 0:
            # Average learning rate across parameter groups; kept from the
            # upstream script, where it fed the now-disabled logging below.
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            # Distributed loss reduction, console logging, sample-image saving,
            # and recovery checkpointing from the upstream timm script are
            # disabled in this benchmark variant.

        if lr_scheduler is not None:
            # timm schedulers step once per optimizer update via step_update().
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

    # Lookahead-style optimizers sync their slow weights after the epoch.
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, args, evaluator=None, log_suffix='', num_batch=1):
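    """Evaluate ``model`` for up to ``num_batch`` batches.

    Returns an OrderedDict with the average validation loss and, when an
    evaluator is supplied, a 'map' entry from ``evaluator.evaluate()``.
    """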
losses_m = AverageMeter()
model.eval()
with torch.no_grad():
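        # As in train_epoch, zip() with range(num_batch) caps the number of
        # evaluated batches.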
for batch_idx, (input, target) in zip(range(num_batch), loader):
output = model(input, target)
loss = output['loss']
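            # Accumulate this batch's detections for later mAP computation.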
if evaluator is not None:
evaluator.add_predictions(output['detections'], target)
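            # Average the loss across ranks when running distributed.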
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
else:
reduced_loss = loss.data
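            # Ensure pending GPU work has finished before reading the loss value.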
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
            # Per-batch timing and console logging from the upstream script are
            # disabled in this benchmark variant.
metrics = OrderedDict([('loss', losses_m.avg)])
if evaluator is not None:
metrics['map'] = evaluator.evaluate()
return metrics