def train()

in data_augmentation/my_training.py [0:0]


def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run a single training epoch over ``train_loader``.

    For each ``(images, target)`` batch this moves the batch to the GPU when
    CUDA is available, runs the forward pass, computes the loss, records loss
    and top-1/top-5 accuracy statistics, and performs one optimizer step.
    Progress is displayed every ``args.print_freq`` iterations, and the
    wall-clock duration of the whole epoch is printed at the end.

    Args:
        train_loader: iterable yielding ``(images, target)`` pairs.
        model: network being trained; switched to train mode here. When
            ``args.augerino`` and ``args.inv_per_class`` are both set it is
            called as ``model(images, target)``, otherwise ``model(images)``.
        criterion: loss function. Called as
            ``criterion(output, target, model, args, reg=args.aug_reg)`` when
            ``args.augerino`` is set, otherwise ``criterion(output, target)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        epoch: epoch index; only used in the progress-display prefix.
        args: options namespace; reads ``augerino``, ``inv_per_class``,
            ``aug_reg``, ``distributed`` and ``print_freq``.
    """
    meter_specs = (
        ("Time", ":6.3f"),
        ("Data", ":6.3f"),
        ("Loss", ":.4e"),
        ("Acc@1", ":6.4f"),
        ("Acc@5", ":6.4f"),
    )
    batch_time, data_time, losses, top1, top5 = [
        AverageMeter(label, fmt) for label, fmt in meter_specs
    ]
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    use_cuda = torch.cuda.is_available()

    model.train()
    end = time.time()
    epoch_start = time.time()
    for step, (images, target) in enumerate(train_loader):
        # Time spent waiting on the loader since the previous batch finished.
        data_time.update(time.time() - end)

        if use_cuda:
            images, target = (
                images.cuda(non_blocking=True),
                target.cuda(non_blocking=True),
            )

        # Forward pass and loss; the augerino path needs extra arguments.
        augerino = args.augerino
        output = (
            model(images, target)
            if augerino and args.inv_per_class
            else model(images)
        )
        loss = (
            criterion(output, target, model, args, reg=args.aug_reg)
            if augerino
            else criterion(output, target)
        )

        # accuracy() is divided by a sample count below, so it appears to
        # return per-batch correct counts rather than percentages -- TODO confirm.
        correct1, correct5 = accuracy(output, target, topk=(1, 5))
        local_batch = images.size(0)
        sample_count = torch.FloatTensor([local_batch])
        if use_cuda:
            sample_count = sample_count.cuda(non_blocking=True)

        if args.distributed:
            # Combine counts across processes (all_reduce defaults to SUM) so
            # the reported accuracy covers the global batch.
            for shared in (correct1, correct5, sample_count):
                torch.distributed.all_reduce(shared)

        acc1 = correct1 / sample_count * 100.0
        acc5 = correct5 / sample_count * 100.0

        # The loss meter uses this process's batch size, while the accuracy
        # meters use the (possibly all-reduced) sample count.
        losses.update(loss.item(), local_batch)
        top1.update(acc1[0], sample_count.item())
        top5.update(acc5[0], sample_count.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
    print("time 1 epoch {:.3f}".format(time.time() - epoch_start))