in tensorflow_gan/python/eval/classifier_metrics.py [0:0]
def _classifier_score_from_logits_helper(logits, streaming=False):
  """Computes the classifier score exp(mean_x KL(p(y|x) || p(y))) from logits.

  All arithmetic is done in float64 for precision. If `logits` arrives in a
  different dtype, the resulting score tensor(s) are cast back to that dtype.

  Args:
    logits: A rank-2 tensor, or a value convertible to one (the rank is
      asserted). Softmax is taken over the last axis, and the marginal class
      distribution is the mean over axis 0. Rows are therefore treated as
      samples and columns as classes.
    streaming: If True, the marginal distribution and the mean of
      sum(p * log p) are accumulated across calls via
      `eval_utils.streaming_mean_tensor_float64`. That accumulation is a mean
      of per-batch means, so it is only exact when every batch of `logits`
      has the same batch size.

  Returns:
    If `streaming` is False, a scalar score tensor. If `streaming` is True, a
    tuple built element-wise from the tuples returned by
    `eval_utils.streaming_mean_tensor_float64`. Presumably these are the score
    value and its update op; confirm against eval_utils.
  """
  tensor = tf.convert_to_tensor(value=logits)
  tensor.shape.assert_has_rank(2)

  # Promote to float64 for maximum precision, remembering the caller's dtype
  # so the final score can be converted back.
  input_dtype = tensor.dtype
  needs_cast = input_dtype != tf.float64
  if needs_cast:
    tensor = tf.cast(tensor, tf.float64)
  probs = tf.nn.softmax(tensor)

  if not streaming:
    # One-shot path: average the per-sample KL against the batch marginal.
    marginal = tf.reduce_mean(input_tensor=probs, axis=0)
    per_sample_kl = kl_divergence(probs, tensor, marginal)
    per_sample_kl.shape.assert_has_rank(1)
    log_scores = (tf.reduce_mean(input_tensor=per_sample_kl),)
  else:
    # Streaming path. The mean KL cannot be streamed directly because q
    # changes with every batch. Averaging p over the batch yields q, which
    # lets the mean KL be split into two independently accumulable terms:
    #   log_score = mean(sum(p * log_softmax(logits), axis=1))
    #               - sum(q * log(q))
    # Both streamed values are means of per-batch means, so they assume a
    # constant batch size.
    marginal_stream = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=probs, axis=0))
    plogp_stream = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(
            input_tensor=tf.reduce_sum(
                input_tensor=probs * tf.nn.log_softmax(tensor), axis=1)))
    log_scores = tuple(
        plogp - tf.reduce_sum(input_tensor=marg * tf.math.log(marg))
        for plogp, marg in zip(plogp_stream, marginal_stream))

  # Apply the identical transform to every element (the value and, when
  # streaming, its companion op) so that they stay consistent.
  scores = tuple(map(tf.exp, log_scores))
  if needs_cast:
    scores = tuple(tf.cast(score, input_dtype) for score in scores)
  return scores if streaming else scores[0]