def __init__()

in tensorflow_gan/python/estimator/gan_estimator.py [0:0]


  def __init__(self,
               model_dir=None,
               generator_fn=None,
               discriminator_fn=None,
               generator_loss_fn=None,
               discriminator_loss_fn=None,
               generator_optimizer=None,
               discriminator_optimizer=None,
               get_hooks_fn=None,
               get_eval_metric_ops_fn=None,
               add_summaries=None,
               use_loss_summaries=True,
               config=None,
               params=None,
               warm_start_from=None,
               is_chief=True):
    """Creates a `GANEstimator` around a closure-built `model_fn`.

    The constructor validates its arguments, bundles the two optimizers, and
    hands a `model_fn` to the base `Estimator`. That `model_fn` builds the GAN
    model for every mode, adds the GAN losses for TRAIN and EVAL, and returns
    the matching `EstimatorSpec`.

    Args:
      model_dir: Directory where model parameters, the graph, etc. are saved.
        It can also be used to load checkpoints from the directory into an
        estimator in order to continue training a previously saved model.
      generator_fn: Python function that accepts a Tensor, a list of Tensors or
        a dict of Tensors and returns the output of the GAN generator. See
        `TF-GAN` for details and examples. If it has an argument named `mode`,
        the Estimator's `mode` (e.g. TRAIN, EVAL, PREDICT) is passed in, which
        is useful for things like batch normalization.
      discriminator_fn: Python function that accepts the output of
        `generator_fn` or real data in the GAN setup, plus `generator_inputs`,
        and outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for details
        and examples.
      generator_loss_fn: Loss function for the generator; takes a `GANModel`
        tuple.
      discriminator_loss_fn: Loss function for the discriminator; takes a
        `GANModel` tuple.
      generator_optimizer: Optimizer for generator updates, or a zero-argument
        function returning one. Such a function is called while the
        `GANEstimator`'s graph is the default graph, so utilities like
        `tf.train.get_or_create_global_step` work inside it.
      discriminator_optimizer: Like `generator_optimizer`, but used for the
        discriminator updates.
      get_hooks_fn: Function that takes a `GANTrainOps` tuple and returns a list
        of hooks. The hooks run the generator and discriminator train ops and
        can implement the GAN training scheme. Defaults to
        `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: Function that takes a `GANModel` and returns a
        dict of metric results keyed by name; the result is passed to
        `tf.estimator.EstimatorSpec` during evaluation.
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
      use_loss_summaries: `True` adds loss summaries, `False` does not, and
        `None` uses the defaults.
      config: `RunConfig` object that configures runtime settings.
      params: Optional `dict` of hyperparameters, forwarded to the Estimator's
        `params` argument so Estimators can be configured from hyperparameter
        tuning. Any entries that are arguments of TF-GAN's `gan_loss` are
        passed to `gan_loss` during training and evaluation.
      warm_start_from: Filepath to a checkpoint or saved model, or a
        `WarmStartSettings` object that configures initialization.
      is_chief: Whether this Estimator runs on the chief or on a worker. Must be
        set correctly when using `SyncReplicasOptimizers`.

    Raises:
      ValueError: If the loss functions are not callable.
      ValueError: If `use_loss_summaries` is neither a boolean nor `None`.
      ValueError: If `get_hooks_fn` is neither callable nor `None`.
    """
    _validate_input_args(generator_loss_fn, discriminator_loss_fn,
                         use_loss_summaries, get_hooks_fn)
    optimizers = Optimizers(generator_optimizer, discriminator_optimizer)

    def _model_fn(features, labels, mode, params):
      """Returns the `EstimatorSpec` for a single `mode`.

      `features` are fed to the generator and `labels` are treated as the real
      data. Raises `ValueError` for a mode other than TRAIN, EVAL or PREDICT.
      """
      mode_keys = tf.estimator.ModeKeys
      if mode not in (mode_keys.TRAIN, mode_keys.EVAL, mode_keys.PREDICT):
        raise ValueError('Mode not recognized: %s' % mode)

      # The GANModel bundles the generator and discriminator architectures.
      model = get_gan_model(mode, generator_fn, discriminator_fn, labels,
                            features, add_summaries)

      # PREDICT builds no losses; its spec is derived from the model alone.
      if mode == mode_keys.PREDICT:
        return get_predict_estimator_spec(model)

      # TRAIN and EVAL both need the GANLoss; `params` may carry extra
      # `gan_loss` keyword arguments.
      loss_kwargs = extract_gan_loss_args_from_params(params) or {}
      losses = tfgan_train.gan_loss(
          model,
          generator_loss_fn,
          discriminator_loss_fn,
          add_summaries=use_loss_summaries,
          **loss_kwargs)

      if mode == mode_keys.TRAIN:
        return get_train_estimator_spec(
            model, losses, optimizers, get_hooks_fn, is_chief=is_chief)
      return get_eval_estimator_spec(model, losses, get_eval_metric_ops_fn)

    super(GANEstimator, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config,
        params=params,
        warm_start_from=warm_start_from)