in gan_eval_metrics.py [0:0]
def train_mnist_classifier(lr=0.001, epochs=50, model_dir='.'):
    """Train the MNIST classifier used for the inception score.

    A ``LeNet`` model is trained with Adam and cross-entropy loss on the
    MNIST training split (batch size 100). After every epoch it is evaluated
    on the MNIST test split. Whenever the test accuracy improves (and always
    after the first epoch), its ``state_dict`` is saved to
    ``<model_dir>/mnist_classifier.pt``.

    Args:
        lr: Learning rate for the Adam optimizer.
        epochs: Number of passes over the training set.
        model_dir: Directory for the checkpoint; created if it does not exist.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {0!s}".format(device))
    train_loader = load_mnist(batchSize=100, train=True)
    test_loader = load_mnist(batchSize=100, train=False)
    model = LeNet().to(device)

    def evaluate():
        """Return the model's classification accuracy (percent) on the test split."""
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        return 100. * correct / len(test_loader.dataset)

    train_criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # exist_ok=True avoids the race between an exists() check and makedirs()
    # when another process creates the directory in between.
    os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, "mnist_classifier.pt")

    # training loop
    print('Started training...')
    best_test_acc = 0.0
    best_test_epoch = 0  # 0 means "no checkpoint saved yet"
    for epoch in range(1, epochs + 1):
        # evaluate() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # NOTE(review): squeeze(1) is a no-op for (batch, classes) logits,
            # and evaluate() does not squeeze. Confirm LeNet's output shape.
            output = model(data).squeeze(1)
            loss = train_criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        test_acc = evaluate()
        print('Test Accuracy: {:.2f}\n'.format(test_acc))
        # Always save after the first epoch so a checkpoint exists even if
        # accuracy never rises above 0%; afterwards save only on improvement.
        if best_test_epoch == 0 or test_acc > best_test_acc:
            best_test_epoch = epoch
            best_test_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)
    print('Finished.')
    print('Best: Epoch: {}, Test-Accuracy: {:.4f}\n'.format(best_test_epoch, best_test_acc))