# evaluations/th_evaluator.py
import warnings

import blobfile as bf
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from PIL import Image
from scipy import linalg
from tqdm import tqdm

from cm import dist_util

from .inception_v3 import InceptionV3


def clip_preproc(preproc_fn, x):
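    """Convert an HWC numpy image to a CLIP input tensor via `preproc_fn`."""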
return preproc_fn(Image.fromarray(x.astype(np.uint8)))


def all_gather(x, dim=0):
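    """Gather `x` from every rank and concatenate the results along `dim`."""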
xs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(xs, x)
return torch.cat(xs, dim=dim)


class FIDStatistics:
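    """
    Mean and covariance of a set of feature activations, tagged with the image
    resolution they were computed at, with support for Frechet distance.
    """
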
def __init__(self, mu: np.ndarray, sigma: np.ndarray, resolution: int):
self.mu = mu
self.sigma = sigma
self.resolution = resolution

    def frechet_distance(self, other, eps=1e-6):
"""
Compute the Frechet distance between two sets of statistics.
"""
mu1, sigma1 = self.mu, self.sigma
mu2, sigma2 = other.mu, other.sigma
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
assert (
sigma1.shape == sigma2.shape
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
            msg = (
                "FID calculation produced a singular product; "
                "adding %s to the diagonal of the covariance estimates" % eps
            )
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


class FIDAndIS:
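    """
    Computes Inception and CLIP features for batches of generated images and
    derives the statistics and scores used for FID, a spatial FID variant,
    Inception Score, and CLIP score, caching reference statistics per file.
    """
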
def __init__(
self,
softmax_batch_size=512,
clip_score_batch_size=512,
path="https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt",
):
import clip
super().__init__()
self.softmax_batch_size = softmax_batch_size
self.clip_score_batch_size = clip_score_batch_size
self.inception = InceptionV3()
with bf.BlobFile(path, "rb") as f:
self.inception.load_state_dict(torch.load(f))
self.inception.eval()
self.inception.to(dist_util.dev())
self.inception_softmax = self.inception.create_softmax_model()
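        # Stagger CLIP loading across ranks (assuming 8 ranks per node) so only
        # one rank per node downloads the weights; the remaining ranks load them
        # from the local cache after the barrier.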
if dist.get_rank() % 8 == 0:
clip_model, self.clip_preproc_fn = clip.load(
"ViT-B/32", device=dist_util.dev()
)
dist.barrier()
if dist.get_rank() % 8 != 0:
clip_model, self.clip_preproc_fn = clip.load(
"ViT-B/32", device=dist_util.dev()
)
dist.barrier()

        # Compute the probe features separately from the final projection.
class ProjLayer(nn.Module):
def __init__(self, param):
super().__init__()
self.param = param

            def forward(self, x):
return x @ self.param

        self.clip_visual = clip_model.visual
self.clip_proj = ProjLayer(self.clip_visual.proj)
self.clip_visual.proj = None

        class TextModel(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.clip_model = clip_model

            def forward(self, x):
return self.clip_model.encode_text(x)

        self.clip_tokenizer = lambda captions: clip.tokenize(captions, truncate=True)
self.clip_text = TextModel(clip_model)
self.clip_logit_scale = clip_model.logit_scale.exp().item()
self.ref_features = {}
self.is_root = not dist.is_initialized() or dist.get_rank() == 0

    def get_statistics(self, activations: np.ndarray, resolution: int):
"""
Compute activation statistics for a batch of images.
        :param activations: an [N x D] batch of activations.
        :param resolution: the image resolution the activations were computed at.
        :return: an FIDStatistics object.
"""
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return FIDStatistics(mu, sigma, resolution)

    def get_preds(self, batch, captions=None):
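        """
        Compute feature activations for a batch of samples scaled to [-1, 1].

        Returns a tuple of (Inception features, spatial Inception features,
        CLIP visual features before the final projection, normalized CLIP text
        features for `captions` (normalized projected image features if captions
        is None), and a uint8 NHWC numpy copy of the batch).
        """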
with torch.no_grad():
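            # Scale samples from [-1, 1] to [0, 255].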
batch = 127.5 * (batch + 1)
np_batch = batch.to(torch.uint8).cpu().numpy().transpose((0, 2, 3, 1))
pred, spatial_pred = self.inception(batch)
pred, spatial_pred = pred.reshape(
[pred.shape[0], -1]
), spatial_pred.reshape([spatial_pred.shape[0], -1])
clip_in = torch.stack(
[clip_preproc(self.clip_preproc_fn, img) for img in np_batch]
)
clip_pred = self.clip_visual(clip_in.half().to(dist_util.dev()))
if captions is not None:
text_in = self.clip_tokenizer(captions)
text_pred = self.clip_text(text_in.to(dist_util.dev()))
else:
                # Hack: with no captions, fall back to the projected image
                # features so a text-feature-shaped tensor is always returned.
text_pred = self.clip_proj(clip_pred.half())
text_pred = text_pred / text_pred.norm(dim=-1, keepdim=True)
return pred, spatial_pred, clip_pred, text_pred, np_batch

    def get_inception_score(
self, activations: np.ndarray, split_size: int = 5000
) -> float:
"""
Compute the inception score using a batch of activations.
:param activations: an [N x D] batch of activations.
:param split_size: the number of samples per split. This is used to
make results consistent with other work, even when
using a different number of samples.
:return: an inception score estimate.
"""
softmax_out = []
for i in range(0, len(activations), self.softmax_batch_size):
acts = activations[i : i + self.softmax_batch_size]
with torch.no_grad():
softmax_out.append(
self.inception_softmax(torch.from_numpy(acts).to(dist_util.dev()))
.cpu()
.numpy()
)
preds = np.concatenate(softmax_out, axis=0)
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
scores = []
for i in range(0, len(preds), split_size):
part = preds[i : i + split_size]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores))

    def get_clip_score(
self, activations: np.ndarray, text_features: np.ndarray
) -> float:
# Sizes should never mismatch, but if they do we want to compute
# _some_ value instead of crash looping.
size = min(len(activations), len(text_features))
activations = activations[:size]
text_features = text_features[:size]
scores_out = []
for i in range(0, len(activations), self.clip_score_batch_size):
acts = activations[i : i + self.clip_score_batch_size]
sub_features = text_features[i : i + self.clip_score_batch_size]
with torch.no_grad():
image_features = self.clip_proj(
torch.from_numpy(acts).half().to(dist_util.dev())
)
image_features = image_features / image_features.norm(
dim=-1, keepdim=True
)
image_features = image_features.detach().cpu().float().numpy()
scores_out.extend(np.sum(sub_features * image_features, axis=-1).tolist())
        return float(np.mean(scores_out) * self.clip_logit_scale)

    def get_activations(self, data, num_samples, global_batch_size, pr_samples=50000):
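        """
        Draw roughly `num_samples` samples from `data` and compute activations.

        Every rank runs the feature extractors and participates in the
        all-gathers; only the root rank accumulates results. On the root,
        returns (Inception, spatial, CLIP) activations truncated to
        `num_samples`, plus up to `pr_samples` raw uint8 images taken from the
        root's local batches. Non-root ranks return four empty lists.
        """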
if self.is_root:
preds = []
spatial_preds = []
clip_preds = []
pr_images = []
for _ in tqdm(range(0, int(np.ceil(num_samples / global_batch_size)))):
batch, cond, _ = next(data)
batch, cond = batch.to(dist_util.dev()), {
k: v.to(dist_util.dev()) for k, v in cond.items()
}
pred, spatial_pred, clip_pred, _, np_batch = self.get_preds(batch)
pred, spatial_pred, clip_pred = (
all_gather(pred).cpu().numpy(),
all_gather(spatial_pred).cpu().numpy(),
all_gather(clip_pred).cpu().numpy(),
)
if self.is_root:
preds.append(pred)
spatial_preds.append(spatial_pred)
clip_preds.append(clip_pred)
if len(pr_images) * np_batch.shape[0] < pr_samples:
pr_images.append(np_batch)
if self.is_root:
preds, spatial_preds, clip_preds, pr_images = (
np.concatenate(preds, axis=0),
np.concatenate(spatial_preds, axis=0),
np.concatenate(clip_preds, axis=0),
np.concatenate(pr_images, axis=0),
)
# assert len(pr_images) >= pr_samples
return (
preds[:num_samples],
spatial_preds[:num_samples],
clip_preds[:num_samples],
pr_images[:pr_samples],
)
else:
return [], [], [], []

    def get_virtual_batch(self, data, num_samples, global_batch_size, resolution):
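        """
        Draw samples and package reference-style statistics for them.

        On the root rank, returns (raw uint8 images, dict of mu/sigma arrays for
        the Inception, spatial, and CLIP features) in the same layout that
        `set_ref_batch` consumes. Non-root ranks return (None, {}).
        """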
preds, spatial_preds, clip_preds, batch = self.get_activations(
data, num_samples, global_batch_size, pr_samples=10000
)
if self.is_root:
fid_stats = self.get_statistics(preds, resolution)
spatial_stats = self.get_statistics(spatial_preds, resolution)
clip_stats = self.get_statistics(clip_preds, resolution)
return batch, dict(
mu=fid_stats.mu,
sigma=fid_stats.sigma,
mu_s=spatial_stats.mu,
sigma_s=spatial_stats.sigma,
mu_clip=clip_stats.mu,
sigma_clip=clip_stats.sigma,
)
else:
return None, dict()

    def set_ref_batch(self, ref_batch):
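        """
        Load precomputed reference statistics (mu/sigma for the Inception,
        spatial, and CLIP features) from an .npz file and cache them keyed by
        `ref_batch`.
        """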
with bf.BlobFile(ref_batch, "rb") as f:
data = np.load(f)
fid_stats = FIDStatistics(mu=data["mu"], sigma=data["sigma"], resolution=-1)
spatial_stats = FIDStatistics(
mu=data["mu_s"], sigma=data["sigma_s"], resolution=-1
)
clip_stats = FIDStatistics(
mu=data["mu_clip"], sigma=data["sigma_clip"], resolution=-1
)
self.ref_features[ref_batch] = (fid_stats, spatial_stats, clip_stats)

    def get_ref_batch(self, ref_batch):
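        """Return the cached (FID, spatial, CLIP) statistics for `ref_batch`."""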
return self.ref_features[ref_batch]
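

# A minimal usage sketch (hypothetical names and paths; assumes that
# torch.distributed has been initialized via cm.dist_util and that `data` is an
# iterator yielding (batch, cond, ...) tuples with batches scaled to [-1, 1]):
#
#     evaluator = FIDAndIS()
#     evaluator.set_ref_batch("reference_stats.npz")  # hypothetical stats file
#     ref_fid_stats, _, _ = evaluator.get_ref_batch("reference_stats.npz")
#     preds, spatial_preds, clip_preds, _ = evaluator.get_activations(
#         data, num_samples=50000, global_batch_size=1000
#     )
#     if evaluator.is_root:
#         fid = evaluator.get_statistics(preds, resolution=64).frechet_distance(
#             ref_fid_stats
#         )
#         inception_score = evaluator.get_inception_score(preds)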