def mnist_inception_score()

in gan_eval_metrics.py [0:0]


def mnist_inception_score(imgs, device=None, batch_size=500, splits=1, model_path='mnist_classifier.pt'):
    """Compute the inception score of the generated images ``imgs`` with a LeNet classifier.

    Adapted from
    https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py

    For each split the score is ``exp(mean_i KL(p(y|x_i) || p(y)))``, where ``p(y)``
    is the mean of the predicted class distributions over that split.

    Args:
        imgs: Torch tensor (bsx1x32x32) of images normalized in the range [-1, 1].
        device: torch device to run the classifier on; defaults to CUDA when
            available, otherwise CPU.
        batch_size: number of images fed to the classifier per forward pass.
        splits: number of equal-sized splits scored separately. When ``len(imgs)``
            is not divisible by ``splits``, the trailing ``len(imgs) % splits``
            predictions are not used (same as the reference implementation).
        model_path: path to the pretrained classifier ``state_dict``.

    Returns:
        Tuple ``(mean, std)`` of the per-split scores.

    Raises:
        ValueError: if ``batch_size < 1``, ``imgs`` is empty, or ``splits`` is not
            in ``[1, len(imgs)]``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    N = len(imgs)

    # Explicit checks instead of `assert` (which is stripped under `python -O`).
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if N < 1:
        raise ValueError("imgs must contain at least one image")
    if not 1 <= splits <= N:
        # Otherwise some split would be empty and its score would be NaN.
        raise ValueError(f"splits must be in [1, {N}], got {splits}")

    # Load the pretrained classifier. map_location lets a checkpoint saved on a
    # different device (e.g. GPU) be loaded onto `device`.
    mnist_classifier = LeNet().to(device)
    mnist_classifier.load_state_dict(torch.load(model_path, map_location=device))
    mnist_classifier.eval()

    # Class probabilities p(y|x) for every image, one row per image.
    preds = np.zeros((N, 10))

    with torch.no_grad():
        for start in range(0, N, batch_size):
            # Inputs must be on the same device as the model's weights.
            batch = imgs[start:start + batch_size].to(device)
            logits = mnist_classifier(batch)
            # dim=1 normalizes over the class axis; omitting dim is deprecated.
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds[start:start + len(probs)] = probs

    # Mean KL divergence between each p(y|x) and the split marginal p(y), exponentiated.
    split_size = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)
        # scipy.stats.entropy(p, q) returns KL(p || q).
        kl_divs = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(kl_divs)))

    return np.mean(split_scores), np.std(split_scores)