def _get_train_op()

in tensorflow_gan/python/estimator/tpu_gan_estimator.py


def _get_train_op(gan_model_fns, loss_fns, gan_loss_kwargs, optimizers,
                  joint_train, gan_train_steps, add_summaries):
  """Return a train op for TPU training."""

  def update_ops(gan_model):
    """Get generator and discriminator update ops for a single training substep.

    We split up the generator and discriminator update ops so that neither set
    is accidentally run more than once per substep. `_get_update_ops` raises an
    error if it finds update ops that aren't associated with either the
    generator or the discriminator. The empty dict passed below stands in for
    the `kwargs` dictionary used on the non-TPU code path (which the helper
    might modify).

    Args:
      gan_model: The GANModel tuple.

    Returns:
      A tuple of lists corresponding to
      (generator_update_ops, discriminator_update_ops).
    """
    return tfgan_train._get_update_ops(  # pylint:disable=protected-access
        {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name)
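
  # Illustration (an assumption, not the helper's exact code): `_get_update_ops`
  # partitions the graph's UPDATE_OPS collection by variable scope, roughly:
  #
  #   all_ops = set(tf.compat.v1.get_collection(
  #       tf.compat.v1.GraphKeys.UPDATE_OPS))
  #   gen_ops = set(tf.compat.v1.get_collection(
  #       tf.compat.v1.GraphKeys.UPDATE_OPS, scope=gen_scope_name))
  #   dis_ops = set(tf.compat.v1.get_collection(
  #       tf.compat.v1.GraphKeys.UPDATE_OPS, scope=dis_scope_name))
  #   unused = all_ops - gen_ops - dis_ops  # non-empty -> error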

  def gen_train_op(gan_model, gan_loss):
    """Get the generator train op for a single training substep.

    Args:
      gan_model: The GANModel tuple.
      gan_loss: The GANLoss tuple.

    Returns:
      An Op that performs a single generator training substep.
    """
    with tf.compat.v1.name_scope('generator_train'):
      return contrib.create_train_op(
          total_loss=gan_loss.generator_loss,
          optimizer=optimizers.gopt,
          variables_to_train=gan_model.generator_variables,
          global_step=None,
          update_ops=update_ops(gan_model)[0])
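
  # Note: like tf_slim's `create_train_op`, `contrib.create_train_op` returns
  # an op that applies the gradient update and, when run, evaluates to the
  # total loss. Passing `global_step=None` keeps each substep from bumping the
  # step counter; it is advanced exactly once at the end of this function.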

  def dis_train_op(gan_model, gan_loss):
    """Get the discriminator train op for a single training substep.

    Args:
      gan_model: The GANModel tuple.
      gan_loss: The GANLoss tuple.

    Returns:
      An Op that performs a single discriminator training substep.
    """
    with tf.compat.v1.name_scope('discriminator_train'):
      return contrib.create_train_op(
          total_loss=gan_loss.discriminator_loss,
          optimizer=optimizers.dopt,
          variables_to_train=gan_model.discriminator_variables,
          global_step=None,
          update_ops=update_ops(gan_model)[1])

  # Either optimize the generator and discriminator sequentially or jointly.
  g_steps = gan_train_steps.generator_train_steps
  d_steps = gan_train_steps.discriminator_train_steps
  joint_steps = 0
  if joint_train:
    joint_steps = min(g_steps, d_steps)
    g_steps -= joint_steps
    d_steps -= joint_steps
  total_steps = joint_steps + d_steps + g_steps
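
  # Worked example: generator_train_steps=2, discriminator_train_steps=3, and
  # joint_train=True give joint_steps=2, g_steps=0, d_steps=1, so total_steps=3
  # substeps run: two joint updates followed by one discriminator-only update.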

  prev_op = tf.no_op()
  scalar_loss = 0
  for i in range(total_steps):
    # For each substep, make sure that the forward-pass ops are created with
    # control dependencies on the train op of the previous substep. We can't
    # just chain the train ops themselves: without a dependency on the forward
    # pass, the weight reads for substep n could be scheduled before the weight
    # updates from substep n-1 have been applied.
    with tf.control_dependencies([prev_op]):
      gan_model = gan_model_fns[i]()
      _maybe_add_summaries(gan_model, add_summaries and i == total_steps - 1)
      gan_loss = _get_loss_for_train(gan_model, loss_fns, gan_loss_kwargs,
                                     add_summaries)
      scalar_loss = gan_loss.discriminator_loss
      if i < joint_steps:
        prev_op = tf.group(
            dis_train_op(gan_model, gan_loss),
            gen_train_op(gan_model, gan_loss),
            name='joint_train_%d' % i)
      elif i < joint_steps + d_steps:
        prev_op = dis_train_op(gan_model, gan_loss)
      else:
        prev_op = gen_train_op(gan_model, gan_loss)

  with tf.control_dependencies([prev_op]):
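    # Every substep passed global_step=None to create_train_op, so this single
    # assign_add is the only place the global step advances per training step.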
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = global_step.assign_add(1)

  return train_op, scalar_loss
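
Usage sketch (illustrative, not the estimator's actual internals): the
(train_op, scalar_loss) pair is consumed by the TRAIN branch of the TPU
model_fn, roughly like the hypothetical wrapper below.

    import tensorflow as tf

    def _train_spec(train_op, scalar_loss):
      # TPUEstimator expects a scalar loss tensor alongside the train op.
      return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=scalar_loss,
          train_op=train_op)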