in tensorflow_gan/python/estimator/tpu_gan_estimator.py [0:0]
def __init__(
self,
# Arguments to construct the `model_fn`.
generator_fn=None,
discriminator_fn=None,
generator_loss_fn=None,
discriminator_loss_fn=None,
generator_optimizer=None,
discriminator_optimizer=None,
prepare_arguments_for_eval_metric_fn=None,
get_eval_metric_ops_fn=None,
add_summaries=None,
joint_train=False,
gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
# TPUEstimator options.
model_dir=None,
config=None,
params=None,
use_tpu=True,
train_batch_size=None,
eval_batch_size=None,
predict_batch_size=None,
batch_axis=None,
eval_on_tpu=True,
export_to_tpu=True,
warm_start_from=None):
"""Initializes a TPUGANEstimator instance.
Args:
generator_fn: A python function that takes a Tensor, Tensor list, or
Tensor dictionary as inputs and returns the outputs of the GAN
generator. See `TFGAN` for more details and examples. Additionally, if
it has an argument called `mode`, the Estimator's `mode` will be passed
in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
normalization. See the usage sketch below the argument list.
discriminator_fn: A python function that takes the output of
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
and examples.
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a
`GANModel` tuple.
generator_optimizer: The optimizer for generator updates, or a function
that takes no arguments and returns an optimizer. This function will be
called when the default graph is this Estimator's graph, so utilities
like `tf.train.get_or_create_global_step` will work.
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
prepare_arguments_for_eval_metric_fn: A function that takes a list of
arguments and returns a nested structure of tensors keyed by name. The
returned tensors must be compatible with TPUEstimatorSpec.eval_metrics
(i.e., in batch-major format, where the batch size is the first
dimension) and will be passed to the provided get_eval_metric_ops_fn.
The arguments must be:
* generator_inputs
* generated_data
* real_data
* discriminator_real_outputs
* discriminator_gen_outputs
The default implementation simply returns the arguments as-is. This
function is executed on the TPU, allowing for compute-heavy eval-only
operations to be performed.
get_eval_metric_ops_fn: A function that takes a list of arguments and
returns a dict of metric results keyed by name, executed on CPU. The
arguments of the function should be the keys of the dict returned by
`prepare_arguments_for_eval_metric_fn` (see that argument's description
for the defaults), and the function should return a dict from metric
string name to the result of calling a metric function, namely a
(metric_tensor, update_op) tuple. See the example sketch below the
argument list.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
This is ignored for jobs that run on TPU, such as the train job if
`use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`.
joint_train: A Python boolean. If `True`, jointly train the generator and
the discriminator. If `False`, sequentially train them. See `train.py`
in TFGAN for more details on the differences between the two GAN
training methods.
gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
of generator to discriminator steps.
model_dir: Same as `TPUEstimator`: Directory to save model parameters,
graph, etc. This can also be used to load checkpoints from the
directory into an estimator to continue training a previously saved
model. If `None`, the model_dir in `config` will be used if set. If both
are set, they must be the same. If both are `None`, a temporary directory
will be used.
config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
object. Cannot be `None`.
params: Same as `TPUEstimator`: An optional `dict` of hyper parameters
that will be passed into `input_fn` and `model_fn`. Keys are names of
parameters, values are basic python types. There are reserved keys for
`TPUEstimator`, including 'batch_size'. If any `params` are args to
TF-GAN's `gan_loss`, they will be passed to `gan_loss` during training
and evaluation.
use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
enabled. Currently, TPU training and evaluation respect this bit, but
eval_on_tpu can override execution of eval. See below. Predict still
happens on CPU.
train_batch_size: Same as `TPUEstimator`: An int representing the global
training batch size. TPUEstimator transforms this global batch size to a
per-shard batch size, as params['batch_size'], when calling `input_fn`
and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
divisible by total number of replicas.
eval_batch_size: Same as `TPUEstimator`: An int representing evaluation
batch size. Must be divisible by total number of replicas.
predict_batch_size: Same as `TPUEstimator`: An int representing the
prediction batch size. Must be divisible by total number of replicas.
batch_axis: Same as `TPUEstimator`: A python tuple of int values
describing how each tensor produced by the Estimator `input_fn` should
be split across the TPU compute shards. For example, if your input_fn
produced (images, labels) where the images tensor is in `HWCN` format,
your shard dimensions would be [3, 0], where 3 corresponds to the `N`
dimension of your images Tensor, and 0 corresponds to the dimension
along which to split the labels to match up with the corresponding
images. If None is supplied, and per_host_input_for_training is True,
batches will be sharded based on the major dimension. If
tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
batch_axis is ignored.
eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
GPU. In this case, the model_fn must return `EstimatorSpec` when called
with `mode` as `EVAL`.
export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
exports a metagraph for serving on TPU in addition to the one on CPU.
warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
checkpoint or SavedModel to warm-start from, or a
`tf.estimator.WarmStartSettings` object to fully configure
warm-starting. If the string filepath is provided instead of a
`WarmStartSettings`, then all variables are warm-started, and it is
assumed that vocabularies and Tensor names are unchanged.
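Example: A minimal usage sketch. The function bodies, layer sizes, and
`run_config` below are illustrative placeholders, not part of the TF-GAN
API:
  def generator_fn(noise, mode):
    # `mode` is optional; when declared it receives TRAIN/EVAL/PREDICT.
    return tf.compat.v1.layers.dense(noise, units=64)

  def discriminator_fn(data, generator_inputs):
    # Returns unbounded logits in the range [-inf, inf].
    return tf.compat.v1.layers.dense(data, units=1)

  def get_eval_metric_ops_fn(generator_inputs, generated_data, real_data,
                             discriminator_real_outputs,
                             discriminator_gen_outputs):
    # Argument names match the default keys produced by
    # `prepare_arguments_for_eval_metric_fn`.
    return {
        'mean_real_logit':
            tf.compat.v1.metrics.mean(discriminator_real_outputs),
    }

  est = tfgan.estimator.TPUGANEstimator(
      generator_fn=generator_fn,
      discriminator_fn=discriminator_fn,
      generator_loss_fn=tfgan.losses.modified_generator_loss,
      discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
      # On TPU, optimizers are commonly wrapped in
      # `tf.compat.v1.tpu.CrossShardOptimizer`.
      generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4),
      discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4),
      get_eval_metric_ops_fn=get_eval_metric_ops_fn,
      train_batch_size=32,
      eval_batch_size=32,
      config=run_config)  # A `tpu_config.RunConfig`; cannot be `None`.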
Raises:
ValueError: If loss functions aren't callable.
ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
tuple.
ValueError: If `gan_train_steps` isn't 1:1 training.
"""
_validate_input_args(
generator_loss_fn, discriminator_loss_fn, gan_train_steps)
loss_fns = LossFns(generator_loss_fn, discriminator_loss_fn)
optimizers = Optimizers(generator_optimizer, discriminator_optimizer)
# Determine the number of GAN models that need to be created in order to
# train with different D:G step ratios on TPU.
required_train_models = _required_train_models(gan_train_steps, joint_train)
effective_train_batch_size = required_train_models * train_batch_size
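# One TPU training step builds `required_train_models` copies of the GAN
# model, so the global batch size is scaled accordingly; each copy is meant
# to consume a `train_batch_size`-sized slice of the scaled batch.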
def _model_fn(features, labels, mode, params):
"""GANEstimator model function."""
if mode not in [
tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT
]:
raise ValueError('Mode not recognized: %s' % mode)
real_data = labels # rename inputs for clarity
generator_inputs = features # rename inputs for clarity
# Collect GANModel builder functions, which encapsulate the GAN model
# architectures. Don't execute them here, since calling them creates the
# TF ops and the variable reads need to be chained after the writes from
# the previous step. Instead, just pass the functions with bound arguments
# down so that they can easily be executed later.
gan_model_fns = _get_gan_model_fns(
mode,
generator_fn,
discriminator_fn,
real_data,
generator_inputs,
num_train_models=required_train_models)
# TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
# remove `add_summaries` logic below.
is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
summary_types = None if is_on_tpu else add_summaries
# Make the TPUEstimatorSpec, which incorporates the model, losses, eval
# metrics, and optimizers (if required).
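# As documented for `params` above, any entries that match TF-GAN's
# `gan_loss` arguments are extracted here and forwarded to loss construction
# in both train and eval mode.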
gan_loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(params)
if mode == tf.estimator.ModeKeys.TRAIN:
estimator_spec = get_train_estimator_spec(
gan_model_fns,
loss_fns,
gan_loss_kwargs,
optimizers,
joint_train,
is_on_tpu,
gan_train_steps,
add_summaries=summary_types)
elif mode == tf.estimator.ModeKeys.EVAL:
estimator_spec = get_eval_estimator_spec(
gan_model_fns,
loss_fns,
gan_loss_kwargs,
prepare_arguments_for_eval_metric_fn,
get_eval_metric_ops_fn,
add_summaries=summary_types)
else: # predict
estimator_spec = get_predict_estimator_spec(gan_model_fns)
assert isinstance(estimator_spec,
tf.compat.v1.estimator.tpu.TPUEstimatorSpec)
return estimator_spec
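# Build the underlying `TPUEstimator` around `_model_fn`. Note that the
# scaled `effective_train_batch_size` (not the user-supplied
# `train_batch_size`) is what gets passed down.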
super(TPUGANEstimator, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
params=params,
use_tpu=use_tpu,
train_batch_size=effective_train_batch_size,
eval_batch_size=eval_batch_size,
predict_batch_size=predict_batch_size,
batch_axis=batch_axis,
eval_on_tpu=eval_on_tpu,
export_to_tpu=export_to_tpu,
warm_start_from=warm_start_from)