def _frechet_classifier_distance_from_activations_helper()

in tensorflow_gan/python/eval/classifier_metrics.py [0:0]


def _frechet_classifier_distance_from_activations_helper(
    activations1, activations2, streaming=False):
  """A helper function evaluating the frechet classifier distance.

  Treats each set of activations as samples from a multivariate Gaussian and
  computes

    FID = ||m - m_w||^2 + Tr(sigma) + Tr(sigma_w)
          - 2 * Tr(sqrt(sigma * sigma_w))

  All intermediate arithmetic is done in float64. The result is cast back to
  the dtype of `activations1`.

  Args:
    activations1: 2D array-like of activations, one row per example.
    activations2: 2D array-like of activations, one row per example. Must have
      the same number of columns as `activations1`, because the two covariance
      matrices are added together.
    streaming: If True, the means and covariances are accumulated with the
      streaming ops from `eval_utils` and a tuple is returned. If False, the
      statistics are computed from the given activations only.

  Returns:
    If `streaming` is False, a scalar Tensor with the distance.

    If `streaming` is True, a tuple with one distance per element of the
    streaming statistics. The first element is the score value; the second is
    built from the statistics' update ops.

  Raises:
    ValueError: If either input does not have rank 2.

  Note:
    In the non-streaming path the covariance is rescaled by N / (N - 1) to
    make it unbiased. Each input therefore needs at least two rows; with a
    single row the rescaling factor divides by zero.
  """
  activations1 = tf.convert_to_tensor(value=activations1)
  activations1.shape.assert_has_rank(2)
  activations2 = tf.convert_to_tensor(value=activations2)
  activations2.shape.assert_has_rank(2)

  # The output dtype follows the first input.
  activations_dtype = activations1.dtype
  # Cast both inputs unconditionally rather than only when activations1 is not
  # float64. Otherwise a (float64, float32) pair would reach the covariance
  # arithmetic below with mismatched dtypes. tf.cast returns a tensor that
  # already has the target dtype unchanged, so this adds no ops in that case.
  activations1 = tf.cast(activations1, tf.float64)
  activations2 = tf.cast(activations2, tf.float64)

  def _unbiased_covariance(activations):
    """Returns the N / (N - 1)-rescaled sample covariance of `activations`."""
    num_examples = tf.cast(tf.shape(input=activations)[0], tf.float64)
    return (num_examples / (num_examples - 1) *
            tfp.stats.covariance(activations))

  # Compute mean and covariance matrices of activations.
  if streaming:
    m = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=activations1, axis=0))
    m_w = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=activations2, axis=0))
    sigma = eval_utils.streaming_covariance(activations1)
    sigma_w = eval_utils.streaming_covariance(activations2)
  else:
    # Wrap in 1-tuples so both branches can be consumed by the same zip below.
    m = (tf.reduce_mean(input_tensor=activations1, axis=0),)
    m_w = (tf.reduce_mean(input_tensor=activations2, axis=0),)
    sigma = (_unbiased_covariance(activations1),)
    sigma_w = (_unbiased_covariance(activations2),)
  # m, m_w, sigma and sigma_w are tuples with one or two elements. The first
  # element is used to calculate the score value and the second (streaming
  # only) is used to create the update_op. The same operation is applied to
  # both so their values stay consistent.

  def _calculate_fid(m, m_w, sigma, sigma_w):
    """Returns the Frechet distance given the sample mean and covariance."""
    # Find the Tr(sqrt(sigma sigma_w)) component of FID.
    sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)

    # Covariance component; trace(A + B) = trace(A) + trace(B).
    trace = tf.linalg.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component

    # Distance between means; squared_difference is more stable than an
    # explicit L2 norm followed by squaring.
    mean = tf.reduce_sum(input_tensor=tf.math.squared_difference(m, m_w))
    # No-op when the inputs were already float64.
    return tf.cast(trace + mean, activations_dtype)

  result = tuple(
      _calculate_fid(m_val, m_w_val, sigma_val, sigma_w_val)
      for m_val, m_w_val, sigma_val, sigma_w_val in zip(m, m_w, sigma, sigma_w))
  return result if streaming else result[0]