in svhn_based_exp/utils.py [0:0]
def train_epoch(optimizer, criterion, trainloader, net, logger, regression, device):
    """Run one training pass over ``trainloader`` and log the epoch summary.

    Puts ``net`` in training mode, then for every ``(inputs, targets)`` batch
    moves both to ``device``, does a forward pass, computes the loss with
    ``criterion``, back-propagates, and steps ``optimizer``.

    Args:
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        criterion: Callable returning the loss for ``(predictions, targets)``.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.
        net: Model being trained; called as ``net(inputs)``.
        logger: Object with ``info`` (and ``warning``) methods used for the
            epoch summary.
        regression: If true, targets are cast to float and the loss is
            computed against column 0 of the outputs; only the loss is
            logged. Otherwise the loss is computed on the full outputs and
            accuracy is tracked using the arg-max over dimension 1.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None. Results are reported through ``logger`` only.
    """
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for num_batches, (inputs, targets) in enumerate(trainloader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)

        if regression:
            # Regression uses only the first output column as the prediction.
            targets = targets.float()
            loss = criterion(outputs[:, 0], targets)
        else:
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        if not regression:
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if num_batches == 0:
        # An empty loader previously crashed on an unbound ``batch_idx``
        # (and a zero ``total``); report it instead of raising.
        logger.warning('==>>> train loader yielded no batches; nothing to train on')
        return

    # Mean of the per-batch losses (not weighted by batch size).
    mean_loss = train_loss / num_batches
    if regression:
        logger.info('==>>> train loss: {:.6f}'.format(mean_loss))
    else:
        # Guard against batches that contained no samples at all.
        accuracy = 100.*correct/total if total else 0.0
        logger.info('==>>> train loss: {:.6f}, accuracy: {:.4f}'.format(mean_loss, accuracy))