in tensorflow_gan/python/estimator/gan_estimator.py [0:0]
def __init__(self,
             model_dir=None,
             generator_fn=None,
             discriminator_fn=None,
             generator_loss_fn=None,
             discriminator_loss_fn=None,
             generator_optimizer=None,
             discriminator_optimizer=None,
             get_hooks_fn=None,
             get_eval_metric_ops_fn=None,
             add_summaries=None,
             use_loss_summaries=True,
             config=None,
             params=None,
             warm_start_from=None,
             is_chief=True):
  """Builds a GANEstimator from user-supplied GAN components.

  The generator/discriminator functions, loss functions and optimizers are
  captured in a closure that is handed to the base Estimator as `model_fn`.

  Args:
    model_dir: Directory where model parameters, the graph, etc. are saved.
      It can also be used to load checkpoints from that directory into an
      estimator in order to continue training a previously saved model.
    generator_fn: Python function that accepts a Tensor, a list of Tensors,
      or a dict of Tensors and returns the outputs of the GAN generator. See
      `TF-GAN` for details and examples. If it declares an argument named
      `mode`, the Estimator's `mode` (e.g. TRAIN, EVAL, PREDICT) is passed
      in, which is useful for things like batch normalization.
    discriminator_fn: Python function that accepts either the output of
      `generator_fn` or real data, together with `generator_inputs`, and
      returns a Tensor in the range [-inf, inf]. See `TF-GAN` for details and
      examples.
    generator_loss_fn: Generator loss function; it receives a `GANModel`
      tuple.
    discriminator_loss_fn: Discriminator loss function; it receives a
      `GANModel` tuple.
    generator_optimizer: Optimizer used for generator updates, or a
      zero-argument callable returning one. The callable is invoked while
      the `GANEstimator`'s graph is the default graph, so utilities such as
      `tf.train.get_or_create_global_step` work inside it.
    discriminator_optimizer: Like `generator_optimizer`, but for the
      discriminator updates.
    get_hooks_fn: Function mapping a `GANTrainOps` tuple to a list of hooks.
      The hooks run the generator and discriminator train ops and can be
      used to implement the GAN training scheme. Defaults to
      `train.get_sequential_train_hooks()`.
    get_eval_metric_ops_fn: Function mapping a `GANModel` to a dict of
      metric results keyed by name. The dict is passed into
      `tf.estimator.EstimatorSpec` during evaluation.
    add_summaries: `None`, a single `SummaryType`, or a list of
      `SummaryType`s.
    use_loss_summaries: `True` adds loss summaries, `False` omits them, and
      `None` falls back to the defaults.
    config: `RunConfig` object that configures the runtime settings.
    params: Optional `dict` of hyperparameters. It is forwarded to the base
      Estimator as `params`, which allows configuring Estimators from
      hyperparameter tuning. Entries that are arguments of TF-GAN's
      `gan_loss` are passed on to `gan_loss` during training and evaluation.
    warm_start_from: Path to a checkpoint or saved model, or a
      `WarmStartSettings` object, used to configure initialization.
    is_chief: Whether this Estimator runs on the chief rather than a worker.
      Must be set appropriately when using SyncReplicasOptimizers.

  Raises:
    ValueError: If the loss functions aren't callable.
    ValueError: If `use_loss_summaries` is neither a boolean nor `None`.
    ValueError: If `get_hooks_fn` is neither callable nor `None`.
  """
  # Reject bad arguments before building anything else.
  _validate_input_args(generator_loss_fn, discriminator_loss_fn,
                       use_loss_summaries, get_hooks_fn)
  gan_optimizers = Optimizers(generator_optimizer, discriminator_optimizer)

  # The argument names below follow the `model_fn` signature expected by
  # `tf.estimator.Estimator`, which passes arguments by name. Keep them.
  def _gan_model_fn(features, labels, mode, params):
    """Estimator `model_fn` that assembles the GAN graph for `mode`.

    `features` are the generator inputs and `labels` are the real data.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    supported_modes = (tf.estimator.ModeKeys.TRAIN,
                       tf.estimator.ModeKeys.EVAL,
                       tf.estimator.ModeKeys.PREDICT)
    if mode not in supported_modes:
      raise ValueError('Mode not recognized: %s' % mode)

    # The GANModel bundles the generator/discriminator architectures.
    # Positional order: (..., real data, generator inputs, ...).
    gan_model = get_gan_model(mode, generator_fn, discriminator_fn,
                              labels, features, add_summaries)

    # Only TRAIN and EVAL need losses; every other (validated) mode is
    # PREDICT.
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      return get_predict_estimator_spec(gan_model)

    # `gan_loss` arguments found in `params` are forwarded to `gan_loss`.
    loss_kwargs = extract_gan_loss_args_from_params(params) or {}
    gan_loss = tfgan_train.gan_loss(
        gan_model,
        generator_loss_fn,
        discriminator_loss_fn,
        add_summaries=use_loss_summaries,
        **loss_kwargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
      return get_train_estimator_spec(
          gan_model, gan_loss, gan_optimizers, get_hooks_fn,
          is_chief=is_chief)
    return get_eval_estimator_spec(gan_model, gan_loss,
                                   get_eval_metric_ops_fn)

  super(GANEstimator, self).__init__(
      model_fn=_gan_model_fn,
      model_dir=model_dir,
      config=config,
      params=params,
      warm_start_from=warm_start_from)