in tensorflow_gan/python/estimator/stargan_estimator.py [0:0]
def __init__(self,
             model_dir=None,
             generator_fn=None,
             discriminator_fn=None,
             loss_fn=None,
             generator_optimizer=None,
             discriminator_optimizer=None,
             get_hooks_fn=None,
             get_eval_metric_ops_fn=None,
             add_summaries=None,
             use_loss_summaries=True,
             config=None,
             params=None):
    """Initializes a StarGANEstimator instance.

    Args:
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `input_data`. Outputs
        a Tensor in the range [-inf, inf]. See `TFGAN` for more details and
        examples.
      loss_fn: A callable that computes the GAN losses. It takes a
        `StarGANModel` namedtuple and returns a `GANLoss` namedtuple.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will be
        called when the default graph is the `StarGANEstimator`'s graph, so
        utilities like `tf.train.get_or_create_global_step` will work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. These hooks are run on the generator and discriminator
        train ops, and can be used to implement the GAN training scheme.
        Defaults to `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: A function that takes a `StarGANModel`, and
        returns a dict of metric results keyed by name. The output of this
        function is passed into `tf.estimator.EstimatorSpec` during evaluation.
      add_summaries: `None`, a single `SummaryType`, or a list of
        `SummaryType`. Forwarded to `get_gan_model`.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults. NOTE: this value is only validated here; it
        is not forwarded to `get_gan_model` or `get_estimator_spec`, so it
        currently has no effect in this constructor.
      config: `RunConfig` object to configure the runtime settings.
      params: Optional `dict` of hyperparameters. Passed through to the base
        Estimator as its `params` argument, which allows configuring
        Estimators from hyperparameter tuning. The internal model function
        ignores the `params` it receives.

    Raises:
      ValueError: If `loss_fn` isn't callable.
      ValueError: If `use_loss_summaries` isn't `True`, `False` or `None`.
      TypeError: If `get_hooks_fn` isn't callable or `None`.
    """
    if not callable(loss_fn):
        raise ValueError('loss_fn must be callable.')
    # NOTE(review): membership is equality-based, so 1 / 0 (== True / False)
    # are also accepted. Kept as-is for backward compatibility.
    if use_loss_summaries not in (True, False, None):
        raise ValueError('use_loss_summaries must be True, False or None.')
    # TypeError (not ValueError) is what existing callers may catch here.
    if get_hooks_fn is not None and not callable(get_hooks_fn):
        raise TypeError('get_hooks_fn must be callable.')

    def _model_fn(features, labels, mode, params):
        """StarGANEstimator model function.

        Args:
          features: In PREDICT mode, an indexable pair whose element 0 is the
            input data and element 1 is its domain label. In other modes, the
            input data itself.
          labels: The input data's domain labels. Ignored in PREDICT mode,
            where the label comes from `features`.
          mode: A `tf.estimator.ModeKeys` value.
          params: Unused.

        Returns:
          The EstimatorSpec built by `get_estimator_spec`.

        Raises:
          ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
        """
        del params  # unused
        if mode not in (
                tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                tf.estimator.ModeKeys.PREDICT):
            raise ValueError('Mode not recognized: %s' % mode)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # Prediction has no labels, so features carry (data, domain_label).
            input_data = features[0]
            input_data_domain_label = features[1]
        else:
            input_data = features  # rename inputs for clarity
            input_data_domain_label = labels  # rename inputs for clarity

        # Make StarGANModel, which encapsulates the GAN model architectures.
        gan_model = get_gan_model(mode, generator_fn, discriminator_fn,
                                  input_data, input_data_domain_label,
                                  add_summaries)

        # Make the EstimatorSpec, which incorporates the StarGANModel, losses,
        # eval, metrics, and optimizers (if required).
        return get_estimator_spec(mode, gan_model, loss_fn,
                                  get_eval_metric_ops_fn, generator_optimizer,
                                  discriminator_optimizer, get_hooks_fn)

    super(StarGANEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config, params=params)