in tensorflow_gan/python/eval/sliced_wasserstein.py [0:0]
def sliced_wasserstein_distance(real_images,
                                fake_images,
                                resolution_min=16,
                                patches_per_image=64,
                                patch_size=7,
                                random_sampling_count=1,
                                random_projection_dim=7 * 7 * 3,
                                use_svd=False):
    """Compute the Wasserstein distance between two distributions of images.

    Both image batches are decomposed with `laplacian_pyramid`. For every
    pyramid level, patches are extracted, normalized and compared, either
    with `_sliced_wasserstein` or, when `use_svd` is set, with
    `_sliced_wasserstein_svd`.

    Note that the measure varies with the number of images. Use 8192 images
    to get numbers comparable to the ones in the original paper.

    Args:
      real_images: (tensor) Real images (batch, height, width, channels).
      fake_images: (tensor) Fake images (batch, height, width, channels).
      resolution_min: (int) Minimum resolution for the Laplacian pyramid. It
        is clamped to the image resolution.
      patches_per_image: (int) Number of patches to extract per image per
        Laplacian level.
      patch_size: (int) Width of a square patch.
      random_sampling_count: (int) Number of random projections to average.
        Only used when `use_svd` is False.
      random_projection_dim: (int) Dimension of the random projection space.
        Only used when `use_svd` is False.
      use_svd: experimental method to compute a more accurate distance.

    Returns:
      List of tuples (distance_real, distance_fake) for each level of the
      Laplacian pyramid from the highest resolution to the lowest.
        distance_real is the distance between two patch sets that were both
          extracted (separately) from the real images.
        distance_fake is the distance between the real patches and the fake
          patches.

    Raises:
      ValueError: If the inputs shapes are incorrect. Input tensor dimensions
      (batch, height, width, channels) are expected to be known at graph
      construction time. In addition height and width must be the same and the
      number of colors should be exactly 3. Real and fake images must have the
      same size.
    """
    height = real_images.shape[1]
    # Static shape checks: width must equal height, three channels, and the
    # fake batch must be shaped like the real one.
    real_images.shape.assert_is_compatible_with([None, None, height, 3])
    fake_images.shape.assert_is_compatible_with(real_images.shape)

    # Select resolutions: powers of two, from the full resolution down to the
    # (clamped) minimum. Only the number of entries is used below.
    resolution_max = int(height)
    resolution_min = min(resolution_min, resolution_max)
    top_exponent = int(np.log2(resolution_max))
    bottom_exponent = int(np.log2(resolution_min))
    resolutions = [
        2**exponent for exponent in range(top_exponent, bottom_exponent - 1, -1)
    ]
    num_levels = len(resolutions)

    # One (initially empty) bucket of patch tensors per pyramid level.
    real_patches = [[] for _ in range(num_levels)]
    fake_patches = [[] for _ in range(num_levels)]
    test_patches = [[] for _ in range(num_levels)]

    # Real images feed two buckets: the reference patches and a second
    # extraction used for the real-vs-real baseline. The call order is kept
    # deliberately (reference first, then baseline) so that helper ops are
    # created in the same sequence as before.
    for lod, level in enumerate(laplacian_pyramid(real_images, num_levels)):
        for buckets in (real_patches, test_patches):
            buckets[lod].append(
                _batch_to_patches(level, patches_per_image, patch_size))

    for lod, level in enumerate(laplacian_pyramid(fake_images, num_levels)):
        fake_patches[lod].append(
            _batch_to_patches(level, patches_per_image, patch_size))

    # Replace each bucket by its normalized patches, level by level.
    for lod in range(num_levels):
        real_patches[lod] = _normalize_patches(real_patches[lod])
        test_patches[lod] = _normalize_patches(test_patches[lod])
        fake_patches[lod] = _normalize_patches(fake_patches[lod])

    # Pick the distance implementation once; both variants take the
    # reference patches first.
    if use_svd:
        def distance(reference, candidate):
            return _sliced_wasserstein_svd(reference, candidate)
    else:
        def distance(reference, candidate):
            return _sliced_wasserstein(reference, candidate,
                                       random_sampling_count,
                                       random_projection_dim)

    # Evaluate scores: (real vs. real baseline, real vs. fake) per level.
    return [
        (distance(real, test), distance(real, fake))
        for real, test, fake in zip(real_patches, test_patches, fake_patches)
    ]