in tensorflow_gan/python/estimator/tpu_gan_estimator.py [0:0]
def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            prepare_arguments_for_eval_metric_fn,
                            get_eval_metric_ops_fn, add_summaries):
  """Builds the `TPUEstimatorSpec` used in eval mode.

  The single GAN model is constructed, its losses are computed without
  reduction, and the default metrics (plus optional custom metrics) are wired
  into the spec through a wrapper that receives a flat list of tensors.

  Args:
    gan_model_fns: A sequence holding exactly one zero-argument callable that
      returns the GAN model.
    loss_fns: An object exposing `g_loss_fn` and `d_loss_fn`, which are
      forwarded to `tfgan_train.gan_loss`.
    gan_loss_kwargs: Optional dict of extra keyword arguments forwarded to
      `tfgan_train.gan_loss`. A falsy value is treated as an empty dict.
    prepare_arguments_for_eval_metric_fn: Optional callable that receives the
      custom metric tensors as keyword arguments and returns the structure
      later passed as keyword arguments to `get_eval_metric_ops_fn`. When
      `None`, the keyword arguments are passed through unchanged as a dict.
    get_eval_metric_ops_fn: Optional callable returning a dict of custom eval
      metric ops. When `None`, only the default metrics are reported.
    add_summaries: Summary configuration forwarded to `_maybe_add_summaries`
      and `tfgan_train.gan_loss`.

  Returns:
    A `tf.compat.v1.estimator.tpu.TPUEstimatorSpec` in `EVAL` mode whose loss
    is the discriminator loss reduced with `SUM_BY_NONZERO_WEIGHTS`.

  Raises:
    AssertionError: If `gan_model_fns` does not have length 1.
    ValueError: If any flattened input to the metric function is not a
      `tf.Tensor`.
    TypeError: Raised from the metric wrapper if `get_eval_metric_ops_fn`
      does not return a dict.
  """
  num_model_fns = len(gan_model_fns)
  assert num_model_fns == 1, (
      '`gan_models` must be length 1 in eval mode. Got length %d' %
      num_model_fns)

  model = gan_model_fns[0]()
  _maybe_add_summaries(model, add_summaries)

  # Metrics need per-example losses, so the batch dimension is kept by asking
  # for no reduction here.
  extra_loss_kwargs = gan_loss_kwargs if gan_loss_kwargs else {}
  unreduced_loss = tfgan_train.gan_loss(
      model,
      loss_fns.g_loss_fn,
      loss_fns.d_loss_fn,
      add_summaries=add_summaries,
      reduction=tf.compat.v1.losses.Reduction.NONE,
      **extra_loss_kwargs)

  if prepare_arguments_for_eval_metric_fn is None:
    # Default: hand the keyword arguments back untouched, as a dict.
    prepare_arguments_for_eval_metric_fn = lambda **metric_args: metric_args

  default_metric_fn = _make_default_metric_fn()

  # `metric_inputs[0]` is the keyword-argument dict for `default_metric_fn`;
  # `metric_inputs[1]`, present only when custom metrics were requested, is the
  # keyword-argument structure for `get_eval_metric_ops_fn`.
  metric_inputs = [_make_default_metric_tensors(unreduced_loss)]
  has_custom_metrics = get_eval_metric_ops_fn is not None
  if has_custom_metrics:
    custom_metric_inputs = prepare_arguments_for_eval_metric_fn(
        **_make_custom_metric_tensors(model))
    metric_inputs.append(custom_metric_inputs)

  eval_loss = tf.compat.v1.losses.compute_weighted_loss(
      unreduced_loss.discriminator_loss,
      loss_collection=None,
      reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

  # TPUEstimatorSpec.eval_metrics takes a function plus a flat list of
  # tensors, yet some structures in `metric_inputs` may be dictionaries (e.g.,
  # generator_inputs and real_data). The inputs are therefore flattened for the
  # spec and the original nesting is rebuilt inside the wrapper.
  def _metric_fn_wrapper(*args):
    """Rebuilds the nested arguments and evaluates the metric functions."""
    restored = tf.nest.pack_sequence_as(metric_inputs, args)
    metric_ops = default_metric_fn(**restored[0])
    if not has_custom_metrics:
      return metric_ops

    # Merge the user-supplied metrics on top of the default ones.
    custom_eval_metric_ops = get_eval_metric_ops_fn(**restored[1])
    if not isinstance(custom_eval_metric_ops, dict):
      raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                      'received: {}'.format(custom_eval_metric_ops))
    metric_ops.update(custom_eval_metric_ops)
    return metric_ops

  flat_metric_inputs = tf.nest.flatten(metric_inputs)
  if any(not isinstance(item, tf.Tensor) for item in flat_metric_inputs):
    raise ValueError('All objects nested within the TF-GAN model must be '
                     'tensors. Instead, types are: %s.' %
                     str([type(v) for v in flat_metric_inputs]))

  predictions = _predictions_from_generator_output(model.generated_data)
  return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      predictions=predictions,
      loss=eval_loss,
      eval_metrics=(_metric_fn_wrapper, flat_metric_inputs))