in build_graph/localization_network/train.py [0:0]
import collections

import torch
import torchnet as tnt

# `util`, `args` (parsed CLI flags), and `validate` are defined at module
# level elsewhere in this file.

def train(iteration, trainloader, valloader, net, optimizer, writer):
    net.train()
    total_iters = len(trainloader)
    epoch = iteration // total_iters

    # Log to TensorBoard roughly ten times per epoch; max() guards against
    # a zero interval (and a modulo-by-zero) on very small loaders.
    plot_every = max(1, int(0.1 * len(trainloader)))

    # One moving-average meter (window of 20) per logged scalar.
    loss_meters = collections.defaultdict(lambda: tnt.meter.MovingAverageValueMeter(20))

    # Training runs in whole epochs; the iteration cap is checked between them.
    while iteration <= args.max_iter:
        for batch in trainloader:
            batch = util.batch_cuda(batch)
            pred, loss_dict = net(batch)

            # Reduce each loss term to a scalar (mean over the batch, and over
            # replicas under DataParallel), then sum the terms into the total loss.
            loss_dict = {k: v.mean() for k, v in loss_dict.items()}
            loss = sum(loss_dict.values())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Batch accuracy from the argmax over class scores.
            _, pred_idx = pred.max(1)
            correct = (pred_idx == batch['label']).float().sum()
            batch_acc = correct / pred.shape[0]
            loss_meters['bAcc'].add(batch_acc.item())

            for k, v in loss_dict.items():
                loss_meters[k].add(v.item())
            loss_meters['total_loss'].add(loss.item())

            if iteration % args.print_every == 0:
                log_str = 'iter: %d (%d + %d/%d) | ' % (iteration, epoch, iteration % total_iters, total_iters)
                log_str += ' | '.join(['%s: %.3f' % (k, v.value()[0]) for k, v in loss_meters.items()])
                print(log_str)

            if iteration % plot_every == 0:
                # The TensorBoard global step is expressed in hundredths of an epoch.
                for key in loss_meters:
                    writer.add_scalar('train/%s' % key, loss_meters[key].value()[0], int(100 * iteration / total_iters))

            iteration += 1

        epoch += 1

        if epoch % 10 == 0:
            with torch.no_grad():
                validate(epoch, iteration, valloader, net, optimizer, writer)
            # Restore train mode in case validate() switched the net to eval().
            net.train()
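
# ---------------------------------------------------------------------------
# Minimal driver sketch (illustrative, not part of the repo). `VideoDataset`
# and `LocalizationNet` are hypothetical stand-ins for the repo's real dataset
# and model classes, and any arg values beyond max_iter/print_every are
# assumed; only the shape of the wiring into train() is the point here.
# ---------------------------------------------------------------------------
def _example_main():
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter

    trainloader = DataLoader(VideoDataset('train'), batch_size=32, shuffle=True, num_workers=4)
    valloader = DataLoader(VideoDataset('val'), batch_size=32, shuffle=False, num_workers=4)

    net = LocalizationNet().cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    writer = SummaryWriter('cv/localization_network')  # hypothetical log dir

    # Start from iteration 0; train() runs until args.max_iter is reached,
    # validating every 10 epochs along the way.
    train(0, trainloader, valloader, net, optimizer, writer)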