def sliced_wasserstein_distance()

in tensorflow_gan/python/eval/sliced_wasserstein.py [0:0]


def sliced_wasserstein_distance(real_images,
                                fake_images,
                                resolution_min=16,
                                patches_per_image=64,
                                patch_size=7,
                                random_sampling_count=1,
                                random_projection_dim=7 * 7 * 3,
                                use_svd=False):
  """Compute the Wasserstein distance between two distributions of images.

  Note that the measure varies with the number of images. Use 8192 images to
  get numbers comparable to the ones in the original paper.

  Args:
      real_images: (tensor) Real images (batch, height, width, channels).
      fake_images: (tensor) Fake images (batch, height, width, channels).
      resolution_min: (int) Minimum resolution for the Laplacian pyramid.
        Clamped to the image resolution if larger.
      patches_per_image: (int) Number of patches to extract per image per
        Laplacian level.
      patch_size: (int) Width of a square patch.
      random_sampling_count: (int) Number of random projections to average.
      random_projection_dim: (int) Dimension of the random projection space.
        The default equals `patch_size**2 * 3` for the default `patch_size`
        of 7; it is not adjusted automatically when `patch_size` changes.
      use_svd: (bool) If True, score with the experimental SVD-based method
        (`_sliced_wasserstein_svd`) instead of random projections; the
        random projection arguments are then unused.

  Returns:
      List of tuples (distance_real, distance_fake) for each level of the
      Laplacian pyramid from the highest resolution to the lowest.
        distance_real is the Wasserstein distance between two separately
          extracted sets of patches from the real images.
        distance_fake is the Wasserstein distance between real and fake images.
  Raises:
      ValueError: If the inputs shapes are incorrect. Input tensor dimensions
      (batch, height, width, channels) are expected to be known at graph
      construction time. In addition height and width must be the same and the
      number of colors should be exactly 3. Real and fake images must have the
      same size.
  """
  height = real_images.shape[1]
  # Width must equal height and there must be exactly 3 channels; the fake
  # images must be shape-compatible with the real ones.
  real_images.shape.assert_is_compatible_with([None, None, height, 3])
  fake_images.shape.assert_is_compatible_with(real_images.shape)

  # The pyramid depth is derived from the spatial size, so that size must be
  # known statically. `int()` raises TypeError on an unknown (None) dimension;
  # report it as the ValueError this function documents instead.
  try:
    resolution_full = int(height)
  except TypeError:
    raise ValueError(
        'The image height/width must be known at graph construction time, '
        'got shape %s.' % (real_images.shape,))

  # Select resolutions: powers of two from 2**int(log2(full size)) down to
  # 2**int(log2(resolution_min)), the base level of detail. Only the number
  # of entries is used below, as the Laplacian pyramid depth.
  resolution_min = min(resolution_min, resolution_full)
  resolutions = [
      2**i for i in range(
          int(np.log2(resolution_full)),
          int(np.log2(resolution_min)) - 1, -1)
  ]
  num_levels = len(resolutions)

  # Gather patches for each level of the Laplacian pyramids. Real images are
  # sampled twice (real and test) so the real-vs-real distance can be reported
  # alongside the real-vs-fake one.
  patches_real, patches_fake, patches_test = (
      [[] for _ in resolutions] for _ in range(3))
  for lod, level in enumerate(laplacian_pyramid(real_images, num_levels)):
    patches_real[lod].append(
        _batch_to_patches(level, patches_per_image, patch_size))
    patches_test[lod].append(
        _batch_to_patches(level, patches_per_image, patch_size))

  for lod, level in enumerate(laplacian_pyramid(fake_images, num_levels)):
    patches_fake[lod].append(
        _batch_to_patches(level, patches_per_image, patch_size))

  for lod in range(num_levels):
    for patches in [patches_real, patches_test, patches_fake]:
      patches[lod] = _normalize_patches(patches[lod])

  # Evaluate scores: one (real-vs-test, real-vs-fake) pair per pyramid level.
  scores = []
  for lod in range(num_levels):
    if not use_svd:
      scores.append(
          (_sliced_wasserstein(patches_real[lod], patches_test[lod],
                               random_sampling_count, random_projection_dim),
           _sliced_wasserstein(patches_real[lod], patches_fake[lod],
                               random_sampling_count, random_projection_dim)))
    else:
      scores.append(
          (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]),
           _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod])))
  return scores