def main_training_loop()

in main.py
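
Runs SGD training for epochs start_epoch through stop_epoch - 1, evaluates top-1/top-5 accuracy on val_loader after every epoch, and periodically saves checkpoints to params.checkpoint_dir.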


import os
import time

import torch
from torch.autograd import Variable


def main_training_loop(train_loader, val_loader, model, loss_fn, start_epoch, stop_epoch, params):
    # init timing
    data_time = 0
    sgd_time = 0
    test_time = 0


    optimizer = torch.optim.SGD(model.parameters(), params.lr, momentum=params.momentum, weight_decay=params.weight_decay, dampening=params.dampening)
    for epoch in range(start_epoch, stop_epoch):
        # adjust_learning_rate (helper not shown in this snippet) sets this epoch's learning rate
        adjust_learning_rate(optimizer, epoch, params)
        model.train()


        # reset per-epoch timing and loss accumulators
        data_time = 0
        sgd_time = 0
        test_time = 0
        start_data_time = time.time()
        avg_loss = 0

        # train
        for i, (x, y) in enumerate(train_loader):
            data_time = data_time + (time.time()-start_data_time)
            x = x.cuda()
            y = y.cuda()
            start_sgd_time=time.time()
            optimizer.zero_grad()
            # Variable wrapping follows the legacy (pre-0.4) PyTorch API
            x_var = Variable(x)
            y_var = Variable(y)

            loss = loss_fn(model, x_var, y_var)
            loss.backward()
            optimizer.step()
            sgd_time = sgd_time + (time.time()-start_sgd_time)

            avg_loss = avg_loss + loss.data[0]  # legacy (pre-0.4) equivalent of loss.item()

            if i % params.print_freq == 0:
                # log the current learning rate plus running-average loss and timings
                print(optimizer.state_dict()['param_groups'][0]['lr'])
                print('Epoch {:d}/{:d} | Batch {:d}/{:d} | Loss {:f} | Data time {:f} | SGD time {:f}'.format(epoch,
                    stop_epoch, i, len(train_loader), avg_loss/float(i+1), data_time/float(i+1), sgd_time/float(i+1)))
            start_data_time = time.time()  # restart the data-loading timer for the next batch


        # test
        model.eval()
        data_time = 0
        start_data_time = time.time()
        top1 = 0
        top5 = 0
        count = 0
        for i, (x, y) in enumerate(val_loader):
            data_time = data_time + (time.time()-start_data_time)
            x = x.cuda()
            y = y.cuda()
            start_test_time = time.time()
            x_var = Variable(x)
            scores = model(x_var)[0]  # the forward pass returns a tuple; class scores are its first element
            top1_this, top5_this = accuracy(scores.data, y)  # accuracy (helper not shown here) returns per-batch top-1/top-5 correct counts
            top1 = top1+top1_this
            top5 = top5+top5_this
            count = count+scores.size(0)
            test_time = test_time + time.time()-start_test_time
            if (i % params.print_freq == 0) or (i == len(val_loader) - 1):
                print('Epoch {:d}/{:d} | Batch {:d}/{:d} | Top-1 {:f} | Top-5 {:f} | Data time {:f} | Test time {:f}'.format(epoch,
                    stop_epoch, i, len(val_loader), top1/float(count), top5/float(count), data_time/float(i+1), test_time/float(i+1)))



        # periodically checkpoint the model (and always at the final epoch)
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            if not os.path.isdir(params.checkpoint_dir):
                os.makedirs(params.checkpoint_dir)
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile)

    return model
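

The loop relies on two helpers that are not shown in this snippet: adjust_learning_rate and accuracy. One plausible shape for them, consistent with how they are called above (adjust_learning_rate mutates the optimizer's learning rate once per epoch, and accuracy returns per-batch top-1/top-5 correct counts that the loop normalises by count), could look like the following. These are illustrative stand-ins, not the actual definitions from main.py.


def adjust_learning_rate(optimizer, epoch, params):
    # Hypothetical stand-in: divide the base learning rate by 10 every 30 epochs.
    lr = params.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(scores, labels, topk=(1, 5)):
    # Hypothetical stand-in: number of correctly classified samples in the batch
    # at each k in topk (counts, not percentages).
    maxk = max(topk)
    _, pred = scores.topk(maxk, 1, True, True)             # (batch, maxk) predicted labels
    correct = pred.eq(labels.view(-1, 1).expand_as(pred))  # (batch, maxk) hit mask
    return [float(correct[:, :k].float().sum()) for k in topk]


A minimal, hypothetical way to invoke main_training_loop, assuming a CUDA device and the legacy (pre-0.4) PyTorch API used by the loop: params is a plain namespace carrying the attributes the loop reads, the model's forward returns a tuple whose first element is the score tensor, and loss_fn takes (model, x, y). The model, loss, and data below are stand-ins for illustration only.


import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TinyNet(nn.Module):
    # Stand-in model whose forward returns a tuple, so model(x_var)[0] picks out the scores.
    def __init__(self, num_classes=10):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        scores = self.fc(x)
        return scores, x


def example_loss_fn(model, x_var, y_var):
    # Stand-in for loss_fn(model, x_var, y_var): plain cross-entropy on the scores.
    scores = model(x_var)[0]
    return F.cross_entropy(scores, y_var)


# Hypothetical hyper-parameters; the attribute names are the ones the loop reads.
params = argparse.Namespace(
    lr=0.1, momentum=0.9, weight_decay=5e-4, dampening=0,
    print_freq=10, save_freq=5, checkpoint_dir='./checkpoints')

# Dummy data, only to make the call self-contained.
train_set = TensorDataset(torch.randn(256, 32), (torch.rand(256) * 10).long())
val_set = TensorDataset(torch.randn(64, 32), (torch.rand(64) * 10).long())
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

model = TinyNet().cuda()
model = main_training_loop(train_loader, val_loader, model, example_loss_fn,
                           start_epoch=0, stop_epoch=2, params=params)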