def train_epoch()

in svhn_based_exp/utils.py [0:0]


def train_epoch(optimizer, criterion, trainloader, net, logger, regression, device):
    """Run one training epoch over ``trainloader`` and log the averaged loss.

    Args:
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.
        net: Model; switched to training mode via ``net.train()``.
        logger: Object with ``info`` and ``warning`` methods (e.g. ``logging.Logger``).
        regression: If true, the loss is computed on the first output column
            against float-cast targets and no accuracy is tracked; otherwise the
            arg-max over dim 1 is compared with ``targets`` to report accuracy.
        device: Device that inputs and targets are moved to before the forward pass.

    The logged loss is the mean of the per-batch loss values (not weighted by
    batch size). An empty ``trainloader`` logs a warning instead of raising.
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # explicit count: a loop index is unbound when the loader is empty

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)

        if regression:
            # Regression uses only the first output column and float targets.
            targets = targets.float()
            loss = criterion(outputs[:, 0], targets)
        else:
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_batches += 1

        if not regression:
            # ``outputs`` was computed before ``optimizer.step()``, so accuracy
            # reflects the parameters used for this batch's forward pass.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if num_batches == 0:
        logger.warning('==>>> train loss: trainloader yielded no batches')
        return

    avg_loss = train_loss / num_batches
    if regression:
        logger.info('==>>> train loss: {:.6f}'.format(avg_loss))
    else:
        # Guard against batches that were all empty (total == 0).
        accuracy = 100. * correct / total if total else 0.0
        logger.info('==>>> train loss: {:.6f}, accuracy: {:.4f}'.format(avg_loss, accuracy))