def __init__()

in tensorflow_gan/python/estimator/stargan_estimator.py [0:0]


  def __init__(self,
               model_dir=None,
               generator_fn=None,
               discriminator_fn=None,
               loss_fn=None,
               generator_optimizer=None,
               discriminator_optimizer=None,
               get_hooks_fn=None,
               get_eval_metric_ops_fn=None,
               add_summaries=None,
               use_loss_summaries=True,
               config=None,
               params=None):
    """Initializes a StarGANEstimator instance.

    The constructor validates its arguments, wraps them in a model function
    (a closure over the arguments below) and hands that model function to the
    parent estimator's constructor.

    Args:
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `input_data`. Outputs
        a Tensor in the range [-inf, inf]. See `TFGAN` for more details and
        examples.
      loss_fn: The loss function on the generator. Takes a `StarGANModel`
        namedtuple and returns a `GANLoss` namedtuple. Must be callable.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will be
        called when the default graph is the `StarGANEstimator`'s graph, so
        utilities like `tf.train.get_or_create_global_step` will
        work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. These hooks are run on the generator and discriminator
        train ops, and can be used to implement the GAN training scheme.
        Defaults to `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
        dict of metric results keyed by name. The output of this function is
        passed into `tf.estimator.EstimatorSpec` during evaluation.
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults. NOTE: within this constructor the value is
        only validated; it is not forwarded to the model function.
      config: `RunConfig` object to configure the runtime settings.
      params: Optional `dict` of hyperparameters.  Will receive what is passed
        to Estimator in `params` parameter. This allows to configure Estimators
        from hyper parameter tuning. The model function built here discards
        the `params` it receives.

    Raises:
      ValueError: If `loss_fn` isn't callable.
      ValueError: If `use_loss_summaries` isn't equal to `True`, `False` or
        `None`.
      TypeError: If `get_hooks_fn` is neither callable nor `None`.
    """
    # Validate up front so a misconfigured estimator fails at construction
    # time instead of inside the model function.
    if not callable(loss_fn):
      raise ValueError('loss_fn must be callable.')
    # List membership compares with `==`, so values equal to True/False (such
    # as 1 and 0) are accepted as well.
    if use_loss_summaries not in [True, False, None]:
      raise ValueError('use_loss_summaries must be True, False or None.')
    if get_hooks_fn is not None and not callable(get_hooks_fn):
      raise TypeError('get_hooks_fn must be callable.')

    def _model_fn(features, labels, mode, params):
      """StarGANEstimator model function.

      Args:
        features: In PREDICT mode, an indexable pair whose element 0 is the
          input data and element 1 is the input data's domain label. In the
          other modes, the input data itself.
        labels: The input data's domain label. Ignored in PREDICT mode.
        mode: One of the `tf.estimator.ModeKeys` values.
        params: Unused.

      Returns:
        Whatever `get_estimator_spec` returns for the constructed model.

      Raises:
        ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
      """
      del params  # unused
      valid_modes = (
          tf.estimator.ModeKeys.TRAIN,
          tf.estimator.ModeKeys.EVAL,
          tf.estimator.ModeKeys.PREDICT,
      )
      if mode not in valid_modes:
        raise ValueError('Mode not recognized: %s' % mode)

      if mode == tf.estimator.ModeKeys.PREDICT:
        # At prediction time `labels` is not consulted, so both the data and
        # its domain label are read from `features`.
        input_data = features[0]
        input_data_domain_label = features[1]
      else:
        input_data = features  # rename inputs for clarity
        input_data_domain_label = labels  # rename inputs for clarity

      # Make StarGANModel, which encapsulates the GAN model architectures.
      gan_model = get_gan_model(mode, generator_fn, discriminator_fn,
                                input_data, input_data_domain_label,
                                add_summaries)

      # Make the EstimatorSpec, which incorporates the StarGANModel, losses,
      # eval, metrics, and optimizers (if required).
      return get_estimator_spec(mode, gan_model, loss_fn,
                                get_eval_metric_ops_fn, generator_optimizer,
                                discriminator_optimizer, get_hooks_fn)

    super(StarGANEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config, params=params)