def train_mnist_classifier()

in gan_eval_metrics.py [0:0]


def train_mnist_classifier(lr=0.001, epochs=50, model_dir='.'):
    """Train a LeNet MNIST classifier for use in the inception score.

    The model is trained with Adam and cross-entropy loss on the MNIST
    training split (batch size 100). After every epoch it is evaluated on the
    MNIST test split, and whenever the test accuracy strictly improves the
    weights (``state_dict``) are written to ``<model_dir>/mnist_classifier.pt``.

    Args:
        lr: Learning rate for the Adam optimizer.
        epochs: Number of passes over the training set.
        model_dir: Directory that receives ``mnist_classifier.pt``; it is
            created if missing. An empty string means the current directory.

    Returns:
        The best test accuracy (in percent) reached over all epochs. This is
        0.0 when ``epochs`` is less than 1, in which case no checkpoint is
        written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {0!s}".format(device))

    train_loader = load_mnist(batchSize=100, train=True)
    test_loader = load_mnist(batchSize=100, train=False)

    model = LeNet().to(device)

    def evaluate():
        """Return the model's accuracy (in percent) on the test split."""
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                # Apply the same squeeze as the training path so both see
                # identically shaped logits (no-op for (N, num_classes)).
                output = model(data).squeeze(1)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        return 100. * correct / len(test_loader.dataset)

    train_criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # exist_ok=True avoids the check-then-create race. An empty model_dir
    # denotes the current directory, for which makedirs('') would raise.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, "mnist_classifier.pt")

    # training loop
    print('Started training...')
    best_test_acc = 0.0
    best_test_epoch = 0
    for epoch in range(1, epochs + 1):
        model.train()  # evaluate() switches to eval mode; switch back each epoch
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data).squeeze(1)
            loss = train_criterion(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        test_acc = evaluate()
        print('Test Accuracy: {:.2f}\n'.format(test_acc))
        # NOTE(review): checkpoint selection uses the test split, so the
        # reported best accuracy is an optimistic estimate.
        if test_acc > best_test_acc:
            best_test_epoch = epoch
            best_test_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)

    print('Finished.')
    print('Best: Epoch: {}, Test-Accuracy: {:.4f}\n'.format(best_test_epoch, best_test_acc))
    return best_test_acc