in tensorflow_gan/python/eval/classifier_metrics.py [0:0]
def _frechet_classifier_distance_from_activations_helper(
    activations1, activations2, streaming=False):
  """A helper function evaluating the frechet classifier distance.

  Computes

    FID = ||m - m_w||^2 + Tr(sigma) + Tr(sigma_w) - 2 * Tr(sqrt(sigma sigma_w))

  where `m` / `m_w` are the per-feature means (taken over axis 0) and
  `sigma` / `sigma_w` the unbiased covariance matrices of `activations1` /
  `activations2`. All intermediate arithmetic is carried out in float64.

  Args:
    activations1: Rank-2 tensor-like of activations. Axis 0 is treated as the
      example axis (means are reduced over it and its size is used as the
      number of examples).
    activations2: Rank-2 tensor-like of activations with the same layout as
      `activations1`.
    streaming: If True, means and covariances are accumulated with the
      streaming helpers from `eval_utils` and a tuple of results is returned.

  Returns:
    If `streaming` is False, a tensor holding the distance. If `streaming` is
    True, a tuple with one distance per element of the streaming statistics:
    the first element is the score value and the second corresponds to the
    update_op. Results are cast back to the original dtype of `activations1`.

  Raises:
    ValueError: If either input does not have rank 2.
  """

  def _unbiased_covariance(activations):
    """Sample covariance of `activations` rescaled by n / (n - 1) (unbiased)."""
    num_examples = tf.cast(tf.shape(input=activations)[0], tf.float64)
    return (num_examples / (num_examples - 1) *
            tfp.stats.covariance(activations))

  activations1 = tf.convert_to_tensor(value=activations1)
  activations1.shape.assert_has_rank(2)
  activations2 = tf.convert_to_tensor(value=activations2)
  activations2.shape.assert_has_rank(2)

  # The final distance is reported in the dtype of the first input.
  activations_dtype = activations1.dtype
  # Cast each input based on its own dtype. Gating both casts on
  # activations1's dtype alone would leave a non-float64 activations2 uncast
  # (when activations1 is already float64) and trigger a dtype mismatch below.
  # tf.cast is a no-op for tensors that are already float64.
  activations1 = tf.cast(activations1, tf.float64)
  activations2 = tf.cast(activations2, tf.float64)

  # Compute mean and covariance matrices of activations.
  if streaming:
    m = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=activations1, axis=0))
    m_w = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=activations2, axis=0))
    sigma = eval_utils.streaming_covariance(activations1)
    sigma_w = eval_utils.streaming_covariance(activations2)
  else:
    # Wrap in 1-tuples so both branches can be consumed the same way below.
    m = (tf.reduce_mean(input_tensor=activations1, axis=0),)
    m_w = (tf.reduce_mean(input_tensor=activations2, axis=0),)
    sigma = (_unbiased_covariance(activations1),)
    sigma_w = (_unbiased_covariance(activations2),)

  # m, m_w, sigma, sigma_w are tuples containing one or two elements: the first
  # element will be used to calculate the score value and the second will be
  # used to create the update_op. We apply the same operation on the two
  # elements to make sure their value is consistent.
  def _calculate_fid(mean1, mean2, cov1, cov2):
    """Returns the Frechet distance given the sample means and covariances."""
    # Tr(sqrt(cov1 cov2)) component of FID.
    sqrt_trace_component = trace_sqrt_product(cov1, cov2)
    # Covariance component; note that trace(A + B) = trace(A) + trace(B).
    trace = tf.linalg.trace(cov1 + cov2) - 2.0 * sqrt_trace_component
    # Squared distance between the means. squared_difference + reduce_sum is
    # equivalent to the squared L2 norm but more stable.
    mean = tf.reduce_sum(
        input_tensor=tf.math.squared_difference(mean1, mean2))
    fid = trace + mean
    if activations_dtype != tf.float64:
      fid = tf.cast(fid, activations_dtype)
    return fid

  result = tuple(
      _calculate_fid(m_val, m_w_val, sigma_val, sigma_w_val)
      for m_val, m_w_val, sigma_val, sigma_w_val in zip(m, m_w, sigma, sigma_w))
  return result if streaming else result[0]