def _classifier_score_from_logits_helper()

in tensorflow_gan/python/eval/classifier_metrics.py [0:0]


def _classifier_score_from_logits_helper(logits, streaming=False):
  """Computes the classifier score exp(mean_x KL(p(y|x) || p(y))) from logits.

  p(y|x) is the row-wise softmax of `logits` and p(y) is its mean over axis 0.
  All arithmetic is carried out in float64. If `logits` arrives in another
  dtype, the result is cast back to that dtype.

  Args:
    logits: A rank-2 tensor, or a value convertible to one, holding classifier
      logits. Presumably laid out as (batch, num_classes): rows are softmaxed
      independently and averaged over axis 0. TODO confirm against callers.
    streaming: If False, score only this batch. If True, accumulate the
      statistics across calls with `eval_utils.streaming_mean_tensor_float64`.

  Returns:
    If `streaming` is False, a scalar score tensor. If True, a tuple with one
    entry per element returned by the streaming-mean helper (the score value
    and its update op). Each entry has had the same exp/cast applied.
  """
  logits = tf.convert_to_tensor(value=logits)
  logits.shape.assert_has_rank(2)

  # Work in float64 for the best precision, but remember the caller's dtype
  # so the final result can be converted back.
  input_dtype = logits.dtype
  needs_recast = input_dtype != tf.float64
  if needs_recast:
    logits = tf.cast(logits, tf.float64)

  probs = tf.nn.softmax(logits)

  if not streaming:
    marginal = tf.reduce_mean(input_tensor=probs, axis=0)
    per_example_kl = kl_divergence(probs, logits, marginal)
    per_example_kl.shape.assert_has_rank(1)
    log_scores = (tf.reduce_mean(input_tensor=per_example_kl),)
  else:
    # NOTE: the streaming means below assume every call sees the same batch
    # size.
    # With q = mean_x p, the log score mean_x sum_y p * (log p - log q)
    # decomposes as
    #   mean_x sum_y p * log_softmax(logits)  -  sum_y q * log q,
    # because mean_x sum_y p * log q == sum_y (mean_x p) * log q.
    # Both terms can therefore be maintained as running means across batches.
    marginal_ops = eval_utils.streaming_mean_tensor_float64(
        tf.reduce_mean(input_tensor=probs, axis=0))
    batch_neg_entropy = tf.reduce_mean(
        input_tensor=tf.reduce_sum(
            input_tensor=probs * tf.nn.log_softmax(logits), axis=1))
    neg_entropy_ops = eval_utils.streaming_mean_tensor_float64(
        batch_neg_entropy)
    # Combine each streaming value with its counterpart (value with value,
    # update op with update op).
    log_scores = tuple(
        neg_ent - tf.reduce_sum(input_tensor=marg * tf.math.log(marg))
        for neg_ent, marg in zip(neg_entropy_ops, marginal_ops))

  # log_scores holds the log score and, when streaming, its update op too.
  # Every element gets the identical transform so the two stay consistent.
  scores = tuple(tf.exp(log_score) for log_score in log_scores)
  if needs_recast:
    scores = tuple(tf.cast(score, input_dtype) for score in scores)

  return scores if streaming else scores[0]