def train_one_epoch()

in torchbenchmark/util/framework/timm/train.py

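The extract omits the module's import block; a minimal set of imports the
function relies on (an assumption here, following timm's public API) would be:

from contextlib import suppress

import torch
from timm.models import model_parameters
from timm.utils import AverageMeter, dispatch_clip_grad, reduce_tensor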

def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):
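    """Run ``model`` through a single training epoch over ``loader``.

    Adapted from timm's reference training loop for benchmarking: timing,
    logging, EMA updates and recovery checkpointing are commented out, and the
    number of batches is capped by ``args.train_num_batch``.
    """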

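    # Disable mixup once the configured cut-off epoch has been reached.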
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

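    # Second-order optimizers (e.g. timm's AdaHessian) need create_graph=True on backward.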
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    # batch_time_m = AverageMeter()
    # data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    # end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
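    # Iterate over at most args.train_num_batch batches; if that is smaller than
    # len(loader), last_batch below never becomes True.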
    for batch_idx, (input, target) in zip(range(args.train_num_batch), loader):
        last_batch = batch_idx == last_idx
        # data_time_m.update(time.time() - end)
        if not args.prefetcher and args.device == "cuda":
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

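        # Forward pass under the supplied autocast context (a no-op suppress() by default).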
        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

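        # Backward pass and optimizer step: with an AMP loss_scaler, the scaler handles
        # scaling, gradient clipping and stepping; otherwise do them manually.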
        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad, mode=args.clip_mode)
            optimizer.step()

        # if model_ema is not None:
        #     model_ema.update(model)
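        # Block until this iteration's CUDA work has finished.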
        if args.device == "cuda":
            torch.cuda.synchronize()
        num_updates += 1
        # batch_time_m.update(time.time() - end)
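        # On logging steps, average the per-group learning rates and, when distributed,
        # all-reduce the loss across ranks before recording it.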
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            # if args.local_rank == 0:
            #     _logger.info(
            #         'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
            #         'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
            #         'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
            #         '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
            #         'LR: {lr:.3e}  '
            #         'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
            #             epoch,
            #             batch_idx, len(loader),
            #             100. * batch_idx / last_idx,
            #             loss=losses_m,
            #             batch_time=batch_time_m,
            #             rate=input.size(0) * args.world_size / batch_time_m.val,
            #             rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
            #             lr=lr,
            #             data_time=data_time_m))

                # if args.save_images and output_dir:
                #     torchvision.utils.save_image(
                #         input,
                #         os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                #         padding=0,
                #         normalize=True)

        # if saver is not None and args.recovery_interval and (
        #         last_batch or (batch_idx + 1) % args.recovery_interval == 0):
        #     saver.save_recovery(epoch, batch_idx=batch_idx)

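        # timm schedulers support per-iteration updates via step_update().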
        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        # end = time.time()
        # end for

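    # For Lookahead-wrapped optimizers, sync the slow weights at the end of the epoch.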
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
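
A minimal sketch of how the function could be invoked in isolation. The
SimpleNamespace stands in for the parsed argparse Namespace, and the synthetic
model, loader and field values below are illustrative assumptions, not how
TorchBench actually wires this up:

from types import SimpleNamespace

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
args = SimpleNamespace(
    device=device, prefetcher=False, channels_last=False, distributed=False,
    mixup_off_epoch=0, clip_grad=None, clip_mode="norm",
    train_num_batch=4, log_interval=2, world_size=1)

model = torch.nn.Linear(16, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss().to(device)
# Tiny synthetic "loader": a list of (input, target) batches already on device.
loader = [(torch.randn(8, 16, device=device),
           torch.randint(0, 10, (8,), device=device)) for _ in range(4)]

train_one_epoch(epoch=0, model=model, loader=loader, optimizer=optimizer,
                loss_fn=loss_fn, args=args)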