def get_eval_estimator_spec()

in tensorflow_gan/python/estimator/tpu_gan_estimator.py [0:0]


def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            prepare_arguments_for_eval_metric_fn,
                            get_eval_metric_ops_fn, add_summaries):
  """Estimator spec for eval case."""
  assert len(gan_model_fns) == 1, (
      '`gan_models` must be length 1 in eval mode. Got length %d' %
      len(gan_model_fns))

  gan_model = gan_model_fns[0]()

  _maybe_add_summaries(gan_model, add_summaries)

  # Eval losses for metrics must preserve batch dimension.
  kwargs = gan_loss_kwargs or {}
  gan_loss_no_reduction = tfgan_train.gan_loss(
      gan_model,
      loss_fns.g_loss_fn,
      loss_fns.d_loss_fn,
      add_summaries=add_summaries,
      reduction=tf.compat.v1.losses.Reduction.NONE,
      **kwargs)

  if prepare_arguments_for_eval_metric_fn is None:
    # Set the default prepare_arguments_for_eval_metric_fn value: a function
    # that returns its arguments in a dict.
    prepare_arguments_for_eval_metric_fn = lambda **kwargs: kwargs

  default_metric_fn = _make_default_metric_fn()
  # Prepare tensors needed for calculating the metrics: the first element in
  # `tensors_for_metric_fn` holds a dict containing the arguments for
  # `default_metric_fn`, and the second element holds a dict for arguments for
  # `get_eval_metric_ops_fn` (if it is not None).
  tensors_for_metric_fn = [_make_default_metric_tensors(gan_loss_no_reduction)]
  if get_eval_metric_ops_fn is not None:
    tensors_for_metric_fn.append(prepare_arguments_for_eval_metric_fn(
        **_make_custom_metric_tensors(gan_model)))

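  # Reduce the unreduced (per-example) discriminator loss to a single scalar,
  # which is reported as the overall eval loss in the TPUEstimatorSpec.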
  scalar_loss = tf.compat.v1.losses.compute_weighted_loss(
      gan_loss_no_reduction.discriminator_loss,
      loss_collection=None,
      reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

  # TPUEstimatorSpec.eval_metrics expects a function and a list of tensors;
  # however, some structures in tensors_for_metric_fn might be dictionaries
  # (e.g., generator_inputs and real_data). We therefore need to flatten
  # tensors_for_metric_fn before passing it to the function, and then restore
  # the original structure inside the function.
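  # For example, tf.nest.flatten([{'a': x}, {'b': y, 'c': z}]) yields
  # [x, y, z], and tf.nest.pack_sequence_as with the same structure and that
  # flat list rebuilds the original nested form.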
  def _metric_fn_wrapper(*args):
    """Unflattens the arguments and pass them to the metric functions."""
    unpacked_arguments = tf.nest.pack_sequence_as(tensors_for_metric_fn, args)
    # Calculate default metrics.
    metrics = default_metric_fn(**unpacked_arguments[0])
    if get_eval_metric_ops_fn is not None:
      # Append custom metrics.
      custom_eval_metric_ops = get_eval_metric_ops_fn(**unpacked_arguments[1])
      if not isinstance(custom_eval_metric_ops, dict):
        raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                        'received: {}'.format(custom_eval_metric_ops))
      metrics.update(custom_eval_metric_ops)

    return metrics

  flat_tensors = tf.nest.flatten(tensors_for_metric_fn)
  if not all(isinstance(t, tf.Tensor) for t in flat_tensors):
    raise ValueError('All objects nested within the TF-GAN model must be '
                     'tensors. Instead, types are: %s.' %
                     str([type(v) for v in flat_tensors]))
  return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      predictions=_predictions_from_generator_output(gan_model.generated_data),
      loss=scalar_loss,
      eval_metrics=(_metric_fn_wrapper, flat_tensors))
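
A minimal sketch of the two custom-metric hooks this spec consumes (hypothetical callbacks, not part of this file). The keyword argument names below are assumed from the usual TF-GAN `get_eval_metric_ops_fn` contract; the actual names come from `_make_custom_metric_tensors`. `prepare_arguments_for_eval_metric_fn` returns a dict of tensors, which is flattened, shipped through `TPUEstimatorSpec.eval_metrics`, and unpacked as keyword arguments to `get_eval_metric_ops_fn`, which must return a dict of metric ops.

import tensorflow as tf


def prepare_arguments_for_eval_metric_fn(generator_inputs, generated_data,
                                         real_data,
                                         discriminator_real_outputs,
                                         discriminator_gen_outputs):
  # Keep only what the metric needs; every value returned here must be a
  # tf.Tensor (possibly nested in dicts/lists) so it survives tf.nest.flatten.
  del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs
  return {'real_data': real_data, 'generated_data': generated_data}


def get_eval_metric_ops_fn(real_data, generated_data):
  # Must return a dict mapping metric names to (value, update_op) pairs.
  return {
      'mean_abs_diff':
          tf.compat.v1.metrics.mean(tf.abs(real_data - generated_data)),
  }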