def gan_train_ops()

in tensorflow_gan/python/train.py [0:0]


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TF-GAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  A `CycleGANModel` is handled by calling this function once per translation
  direction (X->Y and Y->X) and pairing up the resulting train ops.

  Args:
    model: A GANModel, or a CycleGANModel.
    loss: A GANLoss, or (for a CycleGANModel) a loss with `loss_x2y` and
      `loss_y2x` fields.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training. Only used to build the session-run
      hook of an optimizer that is a `tf.compat.v1.train.SyncReplicasOptimizer`.
    **kwargs: Keyword args to pass directly to `contrib.create_train_op` for
      both the generator and discriminator train op. `_get_update_ops` may
      modify this dict before it is forwarded.

  Returns:
    A `GANTrainOps` namedtuple whose elements are, in order: the generator
    train op, the discriminator train op, an op that increments the global
    step by 1, and a list of session-run hooks (one per optimizer that is a
    `tf.compat.v1.train.SyncReplicasOptimizer`). For a CycleGANModel the two
    train-op elements are `(x2y_op, y2x_op)` pairs and the hooks element is
    not passed, so the namedtuple's default is used.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Forward every argument except `model` and `loss` to the per-direction
    # calls. The set is spelled out instead of harvested from `locals()`, so
    # it cannot silently change if new locals are introduced above.
    forwarded_args = dict(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        check_for_unused_update_ops=check_for_unused_update_ops,
        is_chief=is_chief,
        **kwargs)
    with tf.compat.v1.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **forwarded_args)
    with tf.compat.v1.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **forwarded_args)
    # NOTE(review): the session-run hooks and global-step increment ops built
    # by the two sub-calls are discarded; a fresh increment op is created
    # below. A SyncReplicasOptimizer's hooks are therefore never returned on
    # this path -- confirm whether that configuration is meant to work.
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        tf.compat.v1.train.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Session-run hooks required by any SyncReplicasOptimizer below.
  sync_hooks = []

  # Outside the SyncReplicas case the train ops get `global_step=None`, so they
  # presumably do not advance the real global step themselves; the returned
  # `global_step_inc` is for that (confirm against `contrib.create_train_op`).
  generator_global_step = None
  if isinstance(generator_optimizer, tf.compat.v1.train.SyncReplicasOptimizer):
    generator_global_step = _create_sync_replicas_dummy_step(
        'dummy_global_step_generator', global_step)
    # Keep the dummy step equal to the real global step on every update.
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
  with tf.compat.v1.name_scope('generator_train'):
    gen_train_op = contrib.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                tf.compat.v1.train.SyncReplicasOptimizer):
    discriminator_global_step = _create_sync_replicas_dummy_step(
        'dummy_global_step_discriminator', global_step)
    # Keep the dummy step equal to the real global step on every update.
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
  with tf.compat.v1.name_scope('discriminator_train'):
    disc_train_op = contrib.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)


def _create_sync_replicas_dummy_step(name, global_step):
  """Creates a per-player dummy step variable for a SyncReplicasOptimizer.

  The variable is a zero-initialized, non-trainable scalar with the base dtype
  of `global_step`. It is registered only in the GLOBAL_VARIABLES collection.
  `gan_train_ops` passes it as the player's `global_step` to
  `contrib.create_train_op`. It keeps the variable equal to the real global
  step by adding an assign op to that player's update ops.

  Args:
    name: Variable name passed to `tf.compat.v1.get_variable`.
    global_step: The global step variable whose dtype is matched.

  Returns:
    The created variable.
  """
  # TODO(joelshor): Figure out a way to get this work without including the
  # dummy global step in the checkpoint.
  # WARNING: Making this variable a local variable causes sync replicas to
  # hang forever.
  return tf.compat.v1.get_variable(
      name,
      shape=[],
      dtype=global_step.dtype.base_dtype,
      initializer=tf.compat.v1.initializers.zeros(),
      trainable=False,
      collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])