def train()

in build_graph/localization_network/train.py [0:0]


def train(iteration, trainloader, valloader, net, optimizer, writer):
    """Train ``net`` on ``trainloader`` until ``iteration`` exceeds ``args.max_iter``.

    Each step sends the batch through ``util.batch_cuda``, runs the forward
    pass (``net(batch)`` returns ``(pred, loss_dict)``), sums the mean of every
    loss term, and takes one optimizer step. Moving averages (window 20) of
    each loss term, the total loss, and batch accuracy ('bAcc') are printed
    every ``args.print_every`` iterations and written to ``writer`` about ten
    times per epoch. After every 10th completed epoch, ``validate`` is run
    under ``torch.no_grad()``.

    Args:
        iteration: global iteration counter to start from (e.g. when resuming).
        trainloader: iterable of training batches; must support ``len()``.
            Batches are indexed with ``batch['label']``.
        valloader: passed through unchanged to ``validate``.
        net: model called as ``net(batch)`` -> ``(pred, loss_dict)``.
        optimizer: optimizer whose ``zero_grad``/``step`` update ``net``.
        writer: summary writer exposing ``add_scalar(tag, value, step)``.

    Raises:
        ValueError: if ``trainloader`` is empty; the loop could never advance.

    Note:
        ``args.max_iter`` is only checked between epochs, so training can run
        up to one full epoch past it.
    """
    net.train()

    total_iters = len(trainloader)
    if total_iters == 0:
        # Without this guard, the division below raises an opaque
        # ZeroDivisionError (and the while-loop could never make progress).
        raise ValueError('trainloader is empty; cannot train')
    epoch = iteration // total_iters
    # Log to the writer roughly 10 times per epoch. Clamp to 1 so loaders with
    # fewer than 10 batches don't produce a modulo-by-zero below.
    plot_every = max(1, int(0.1 * total_iters))
    loss_meters = collections.defaultdict(lambda: tnt.meter.MovingAverageValueMeter(20))

    while iteration <= args.max_iter:

        for batch in trainloader:

            batch = util.batch_cuda(batch)
            pred, loss_dict = net(batch)

            # Reduce each loss term to a scalar. Presumably this handles
            # per-device losses (e.g. under DataParallel) — confirm with net.
            loss_dict = {k: v.mean() for k, v in loss_dict.items()}
            loss = sum(loss_dict.values())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Assumes pred is (batch, num_classes) scores — TODO confirm.
            _, pred_idx = pred.max(1)
            correct = (pred_idx == batch['label']).float().sum()
            batch_acc = correct / pred.shape[0]
            loss_meters['bAcc'].add(batch_acc.item())

            for k, v in loss_dict.items():
                loss_meters[k].add(v.item())
            loss_meters['total_loss'].add(loss.item())

            if iteration % args.print_every == 0:
                log_str = 'iter: %d (%d + %d/%d) | ' % (iteration, epoch, iteration % total_iters, total_iters)
                log_str += ' | '.join(['%s: %.3f' % (k, v.value()[0]) for k, v in loss_meters.items()])
                print(log_str)

            if iteration % plot_every == 0:
                # The x-axis step is measured in hundredths of an epoch.
                step = int(100 * iteration / total_iters)
                for key, meter in loss_meters.items():
                    writer.add_scalar('train/%s' % key, meter.value()[0], step)

            iteration += 1

        epoch += 1

        if epoch % 10 == 0:
            with torch.no_grad():
                validate(epoch, iteration, valloader, net, optimizer, writer)
            # validate may put net into eval mode. Restore training mode so
            # dropout/batch-norm behave correctly for the following epochs.
            net.train()