in tensorflow_gan/python/train.py [0:0]
import tensorflow as tf

# Aliases below match how this excerpt refers to TF-GAN internals; the exact
# import lines in train.py may differ slightly. `_get_update_ops` is defined
# elsewhere in this module.
from tensorflow_gan.python import contrib_utils as contrib
from tensorflow_gan.python import namedtuples


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to `create_train_op`.
    **kwargs):
"""Returns GAN train ops.
The highest-level call in TF-GAN. It is composed of functions that can also
be called, should a user require more control over some part of the GAN
training process.
Args:
model: A GANModel.
loss: A GANLoss.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: The optimizer for the discriminator updates.
check_for_unused_update_ops: If `True`, throws an exception if there are
update ops outside of the generator or discriminator scopes.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
**kwargs: Keyword args to pass directly to
`training.create_train_op` for both the generator and
discriminator train op.
Returns:
A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
be used to train a generator/discriminator pair.
"""
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than `model` and `loss` from locals.
    # The contents of locals() should not be modified, and modifying them may
    # not affect values, so make a copy.
    # https://docs.python.org/2/library/functions.html#locals.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with tf.compat.v1.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with tf.compat.v1.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
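    # Note: for CycleGAN, each train-op field of the returned GANTrainOps
    # holds a 2-tuple of ops (one per direction); the two directions share a
    # single global-step increment.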
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        tf.compat.v1.train.get_or_create_global_step().assign_add(1))

  # Create the global step increment op.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)
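  # Neither train op below advances this real global step directly (each
  # receives either `None` or a dummy step), so this increment op is returned
  # in GANTrainOps and is expected to be run once per training iteration,
  # e.g. by the TF-GAN training hooks.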

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Get the sync hooks if these are needed.
  sync_hooks = []

  generator_global_step = None
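  # A `SyncReplicasOptimizer` needs a global step for its internal
  # bookkeeping, but handing the real global step to both `create_train_op`
  # calls would advance it twice per iteration. Each sync optimizer therefore
  # gets its own dummy step, kept in lockstep with the real one via an extra
  # assign op in the update ops.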
  if isinstance(generator_optimizer, tf.compat.v1.train.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = tf.compat.v1.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=tf.compat.v1.initializers.zeros(),
        trainable=False,
        collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))

  with tf.compat.v1.name_scope('generator_train'):
    gen_train_op = contrib.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                tf.compat.v1.train.SyncReplicasOptimizer):
    # See the comment above `generator_global_step`.
    discriminator_global_step = tf.compat.v1.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=tf.compat.v1.initializers.zeros(),
        trainable=False,
        collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))

  with tf.compat.v1.name_scope('discriminator_train'):
    disc_train_op = contrib.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)
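

# Usage sketch: a minimal, illustrative example, not part of train.py. It
# assumes graph-mode TF 1.x behavior and the public TF-GAN endpoints
# `tfgan.gan_model` and `tfgan.gan_loss`; the tiny generator and discriminator
# below are hypothetical stand-ins.
def _example_gan_train_ops_usage():
  import tensorflow_gan as tfgan  # Deferred to keep the sketch self-contained.

  def generator_fn(noise):
    # Toy generator: one dense layer producing flat 28x28 "images".
    return tf.compat.v1.layers.dense(noise, 28 * 28, activation=tf.tanh)

  def discriminator_fn(data, unused_conditioning):
    # Toy discriminator: one logit per example.
    return tf.compat.v1.layers.dense(data, 1)

  noise = tf.random.normal([16, 64])
  real_images = tf.random.normal([16, 28 * 28])  # Stand-in for real data.

  model = tfgan.gan_model(generator_fn, discriminator_fn,
                          real_data=real_images, generator_inputs=noise)
  loss = tfgan.gan_loss(model)
  gen_opt = tf.compat.v1.train.AdamOptimizer(1e-4, beta1=0.5)
  dis_opt = tf.compat.v1.train.AdamOptimizer(1e-4, beta1=0.5)
  return gan_train_ops(model, loss, gen_opt, dis_opt)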