in main.py [0:0]
def main_training_loop(train_loader, val_loader, model, loss_fn, start_epoch, stop_epoch, params):
    """Run the SGD train/eval loop from start_epoch to stop_epoch (exclusive).

    Each epoch: adjusts the learning rate, trains over train_loader, evaluates
    top-1/top-5 accuracy over val_loader, and periodically checkpoints the
    model state to params.checkpoint_dir.

    Args:
        train_loader: iterable of (x, y) training batches.
        val_loader: iterable of (x, y) validation batches.
        model: the network to train (moved to CUDA by the caller).
        loss_fn: callable (model, x, y) -> scalar loss tensor.
        start_epoch, stop_epoch: epoch range [start_epoch, stop_epoch).
        params: config object; reads lr, momentum, weight_decay, dampening,
            print_freq, save_freq, checkpoint_dir.

    Returns:
        The trained model.
    """
    optimizer = torch.optim.SGD(model.parameters(), params.lr,
                                momentum=params.momentum,
                                weight_decay=params.weight_decay,
                                dampening=params.dampening)
    for epoch in range(start_epoch, stop_epoch):
        adjust_learning_rate(optimizer, epoch, params)

        # ---- train ----
        model.train()
        data_time = 0.0   # cumulative seconds spent waiting on the loader
        sgd_time = 0.0    # cumulative seconds spent in forward/backward/step
        test_time = 0.0
        avg_loss = 0.0
        start_data_time = time.time()
        for i, (x, y) in enumerate(train_loader):
            data_time += time.time() - start_data_time
            x = x.cuda()
            y = y.cuda()
            start_sgd_time = time.time()
            optimizer.zero_grad()
            loss = loss_fn(model, x, y)
            loss.backward()
            optimizer.step()
            sgd_time += time.time() - start_sgd_time
            # .item() extracts the Python scalar; loss.data[0] breaks on
            # PyTorch >= 0.4 where losses are 0-dim tensors.
            avg_loss += loss.item()
            if i % params.print_freq == 0:
                print(optimizer.state_dict()['param_groups'][0]['lr'])
                print('Epoch {:d}/{:d} | Batch {:d}/{:d} | Loss {:f} | Data time {:f} | SGD time {:f}'.format(epoch,
                      stop_epoch, i, len(train_loader), avg_loss/float(i+1), data_time/float(i+1), sgd_time/float(i+1)))
            start_data_time = time.time()

        # ---- evaluate ----
        model.eval()
        data_time = 0.0
        start_data_time = time.time()
        top1 = 0
        top5 = 0
        count = 0
        # no_grad: inference only — skip building the autograd graph.
        with torch.no_grad():
            for i, (x, y) in enumerate(val_loader):
                data_time += time.time() - start_data_time
                x = x.cuda()
                y = y.cuda()
                start_test_time = time.time()
                scores = model(x)[0]
                top1_this, top5_this = accuracy(scores.data, y)
                top1 += top1_this
                top5 += top5_this
                count += scores.size(0)
                test_time += time.time() - start_test_time
                if (i % params.print_freq == 0) or (i == len(val_loader)-1):
                    print('Epoch {:d}/{:d} | Batch {:d}/{:d} | Top-1 {:f} | Top-5 {:f} | Data time {:f} | Test time {:f}'.format(epoch,
                          stop_epoch, i, len(val_loader), top1/float(count), top5/float(count), data_time/float(i+1), test_time/float(i+1)))
                start_data_time = time.time()

        # ---- checkpoint ----
        if (epoch % params.save_freq == 0) or (epoch == stop_epoch-1):
            if not os.path.isdir(params.checkpoint_dir):
                os.makedirs(params.checkpoint_dir)
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)
    return model