def __init__()

in tensorflow_gan/python/estimator/tpu_gan_estimator.py [0:0]


  def __init__(
      self,
      # Arguments to construct the `model_fn`.
      generator_fn=None,
      discriminator_fn=None,
      generator_loss_fn=None,
      discriminator_loss_fn=None,
      generator_optimizer=None,
      discriminator_optimizer=None,
      prepare_arguments_for_eval_metric_fn=None,
      get_eval_metric_ops_fn=None,
      add_summaries=None,
      joint_train=False,
      gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
      # TPUEstimator options.
      model_dir=None,
      config=None,
      params=None,
      use_tpu=True,
      train_batch_size=None,
      eval_batch_size=None,
      predict_batch_size=None,
      batch_axis=None,
      eval_on_tpu=True,
      export_to_tpu=True,
      warm_start_from=None):
    """Initializes a TPUGANEstimator instance.

    Args:
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `generator_inputs`.
        Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
        and examples.
      generator_loss_fn: The loss function on the generator. Takes a `GANModel`
        tuple.
      discriminator_loss_fn: The loss function on the discriminator. Takes a
        `GANModel` tuple.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will
        be called when the default graph is the `GANEstimator`'s graph, so
        utilities like `tf.train.get_or_create_global_step` will
        work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      prepare_arguments_for_eval_metric_fn: A function that takes a list of
        arguments and returns a nested structure of tensors keyed by name. The
        returned tensors must be compatible with TPUEstimatorSpec.eval_metrics
        (i.e., in batch-major format, where the batch size is the first
        dimension) and will be passed to the provided get_eval_metric_ops_fn.
        The arguments must be:
            * generator_inputs
            * generated_data
            * real_data
            * discriminator_real_outputs
            * discriminator_gen_outputs
        The default implementation simply returns the arguments as-is. This
        function is executed on the TPU, allowing compute-heavy eval-only
        operations to be performed there.
      get_eval_metric_ops_fn: A function that takes a list of arguments and
        returns a dict of metric results keyed by name, executed on the CPU.
        The arguments of the function should be the keys of the dict returned
        by `prepare_arguments_for_eval_metric_fn` (see that argument above for
        the defaults), and it should return a dict from metric string name to
        the result of calling a metric function, namely a
        (metric_tensor, update_op) tuple.
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
        This is ignored for jobs that run on TPU, such as the train job if
        `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`.
      joint_train: A Python boolean. If `True`, jointly train the generator and
        the discriminator. If `False`, sequentially train them. See `train.py`
        in TFGAN for more details on the differences between the two GAN
        training methods.
      gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
        of generator to discriminator steps.
      model_dir: Same as `TPUEstimator`: Directory to save model parameters,
        graph, etc. This can also be used to load checkpoints from the
        directory into an estimator to continue training a previously saved
        model. If `None`, the model_dir in `config` will be used if set. If both
        are set, they must be the same. If both are `None`, a temporary
        directory will be used.
      config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
        object. Cannot be `None`.
      params: Same as `TPUEstimator`: An optional `dict` of hyperparameters
        that will be passed into `input_fn` and `model_fn`.  Keys are names of
        parameters, values are basic python types. There are reserved keys for
        `TPUEstimator`, including 'batch_size'. If any `params` are args to
        TF-GAN's `gan_loss`, they will be passed to `gan_loss` during training
        and evaluation.
      use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
        enabled. Currently, TPU training and evaluation respect this bit, but
        `eval_on_tpu` (see below) can override execution of the eval job.
        Prediction still happens on the CPU.
      train_batch_size: Same as `TPUEstimator`: An int representing the global
        training batch size. TPUEstimator transforms this global batch size to a
        per-shard batch size, as params['batch_size'], when calling `input_fn`
        and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
        divisible by the total number of replicas.
      eval_batch_size: Same as `TPUEstimator`: An int representing evaluation
        batch size. Must be divisible by the total number of replicas.
      predict_batch_size: Same as `TPUEstimator`: An int representing the
        prediction batch size. Must be divisible by the total number of
        replicas.
      batch_axis: Same as `TPUEstimator`: A python tuple of int values
        describing how each tensor produced by the Estimator `input_fn` should
        be split across the TPU compute shards. For example, if your input_fn
        produced (images, labels) where the images tensor is in `HWCN` format,
        your shard dimensions would be [3, 0], where 3 corresponds to the `N`
        dimension of your images Tensor, and 0 corresponds to the dimension
        along which to split the labels to match up with the corresponding
        images. If None is supplied, and per_host_input_for_training is True,
        batches will be sharded based on the major dimension. If
        tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
        batch_axis is ignored.
      eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
        GPU. In this case, the model_fn must return `EstimatorSpec` when called
        with `mode` as `EVAL`.
      export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
        exports a metagraph for serving on TPU in addition to the one on CPU.
      warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
        checkpoint or SavedModel to warm-start from, or a
        `tf.estimator.WarmStartSettings` object to fully configure
        warm-starting.  If the string filepath is provided instead of a
        `WarmStartSettings`, then all variables are warm-started, and it is
        assumed that vocabularies and Tensor names are unchanged.

    Raises:
      ValueError: If loss functions aren't callable.
      ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
        tuple.
      ValueError: If `gan_train_steps` isn't 1:1 training.
    """
    _validate_input_args(
        generator_loss_fn, discriminator_loss_fn, gan_train_steps)
    loss_fns = LossFns(generator_loss_fn, discriminator_loss_fn)
    optimizers = Optimizers(generator_optimizer, discriminator_optimizer)

    # Determine the number of GAN models that need to be created in order to
    # support different D:G training ratios on TPU.
    required_train_models = _required_train_models(gan_train_steps, joint_train)
    effective_train_batch_size = required_train_models * train_batch_size
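    # For example (a sketch, based on the default settings): with
    # `GANTrainSteps(1, 1)` and `joint_train=False`, one sub-model is typically
    # built for the generator update and one for the discriminator update, so
    # the global batch size handed to `TPUEstimator` below would be
    # `2 * train_batch_size`.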

    def _model_fn(features, labels, mode, params):
      """GANEstimator model function."""
      if mode not in [
          tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
          tf.estimator.ModeKeys.PREDICT
      ]:
        raise ValueError('Mode not recognized: %s' % mode)
      real_data = labels  # rename inputs for clarity
      generator_inputs = features  # rename inputs for clarity

      # Collect GANModel builder functions, which encapsulate the GAN model
      # architectures. Don't execute them here, since calling them creates the
      # TF ops and the variable reads need to be chained after the writes from
      # the previous step. Instead, just pass the functions with bound
      # arguments down so that they can easily be executed later.
      gan_model_fns = _get_gan_model_fns(
          mode,
          generator_fn,
          discriminator_fn,
          real_data,
          generator_inputs,
          num_train_models=required_train_models)

      # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
      # remove `add_summaries` logic below.
      is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
      summary_types = None if is_on_tpu else add_summaries

      # Make the TPUEstimatorSpec, which incorporates the model, losses, eval
      # metrics, and optimizers (if required).
      gan_loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(params)
      if mode == tf.estimator.ModeKeys.TRAIN:
        estimator_spec = get_train_estimator_spec(
            gan_model_fns,
            loss_fns,
            gan_loss_kwargs,
            optimizers,
            joint_train,
            is_on_tpu,
            gan_train_steps,
            add_summaries=summary_types)
      elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = get_eval_estimator_spec(
            gan_model_fns,
            loss_fns,
            gan_loss_kwargs,
            prepare_arguments_for_eval_metric_fn,
            get_eval_metric_ops_fn,
            add_summaries=summary_types)
      else:  # predict
        estimator_spec = get_predict_estimator_spec(gan_model_fns)
      assert isinstance(estimator_spec,
                        tf.compat.v1.estimator.tpu.TPUEstimatorSpec)

      return estimator_spec

    super(TPUGANEstimator, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config,
        params=params,
        use_tpu=use_tpu,
        train_batch_size=effective_train_batch_size,
        eval_batch_size=eval_batch_size,
        predict_batch_size=predict_batch_size,
        batch_axis=batch_axis,
        eval_on_tpu=eval_on_tpu,
        export_to_tpu=export_to_tpu,
        warm_start_from=warm_start_from)
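
The following is a minimal usage sketch, not part of the source file. The model
functions (`unconditional_generator`, `unconditional_discriminator`), the
input_fn, the checkpoint directory, and all hyperparameter values are
illustrative assumptions; only the `TPUGANEstimator` constructor arguments
mirror the signature documented above.

import tensorflow as tf
import tensorflow_gan as tfgan


def unconditional_generator(noise, mode):
  # `mode` is passed in by the estimator (TRAIN/EVAL/PREDICT); a real model
  # might use it to toggle batch normalization.
  del mode
  return tf.compat.v1.layers.dense(noise, 28 * 28, activation=tf.tanh)


def unconditional_discriminator(data, unused_generator_inputs):
  # Returns unbounded logits, as required by the `discriminator_fn` contract.
  return tf.compat.v1.layers.dense(tf.compat.v1.layers.flatten(data), 1)


config = tf.compat.v1.estimator.tpu.RunConfig(
    model_dir='/tmp/tpu_gan_model',  # Placeholder path.
    tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(iterations_per_loop=100))

gan_estimator = tfgan.estimator.TPUGANEstimator(
    generator_fn=unconditional_generator,
    discriminator_fn=unconditional_discriminator,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    joint_train=True,
    config=config,
    train_batch_size=32,
    use_tpu=False)  # Set to True when running on an actual TPU worker.

# `train_input_fn` (user-supplied) must return (generator_inputs, real_data)
# and honor params['batch_size'], per the TPUEstimator input contract.
# gan_estimator.train(train_input_fn, max_steps=1000)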