in tensorflow_gan/python/contrib_utils.py

def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    summarize_gradients=False,
                    gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    check_numerics=True):
"""Creates an `Operation` that evaluates the gradients and returns the loss.
Args:
total_loss: A `Tensor` representing the total loss.
optimizer: A tf.Optimizer to use for computing the gradients.
global_step: A `Tensor` representing the global step variable. If left as
`_USE_GLOBAL_STEP`, then tf.train.global_step() is used.
update_ops: An optional list of updates to execute. If `update_ops` is
`None`, then the update ops are set to the contents of the
`tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
a warning will be displayed.
variables_to_train: an optional list of variables to train. If None, it will
default to all tf.trainable_variables().
transform_grads_fn: A function which takes a single argument, a list of
gradient to variable pairs (tuples), performs any requested gradient
updates, such as gradient clipping or multipliers, and returns the updated
list.
summarize_gradients: Whether or not add summaries for each gradient.
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
aggregation_method: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
colocate_gradients_with_ops: Whether or not to try colocating the gradients
with the ops that generated them.
check_numerics: Whether or not we apply check_numerics.
Returns:
A `Tensor` that when evaluated, computes the gradients and returns the total
loss value.
"""
  if global_step is _USE_GLOBAL_STEP:
    global_step = tf.compat.v1.train.get_or_create_global_step()

  # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
  global_update_ops = set(
      tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
  if update_ops is None:
    update_ops = global_update_ops
  else:
    update_ops = set(update_ops)
  if not global_update_ops.issubset(update_ops):
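    # (A common way to trigger this: batch-norm moving-average updates that
    # are registered in `UPDATE_OPS` but were omitted from the `update_ops`
    # argument, so they would never run.)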
    tf.compat.v1.logging.warning(
        'update_ops in create_train_op does not contain all the '
        'update_ops in GraphKeys.UPDATE_OPS')

  # Make sure update_ops are computed before total_loss.
  if update_ops:
    with tf.control_dependencies(update_ops):
      barrier = tf.no_op(name='update_barrier')
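    # `_with_dependencies` is a private helper in this module; it re-emits
    # `total_loss` with a control dependency on `barrier`, so evaluating the
    # loss forces the update ops to run first.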
    total_loss = _with_dependencies([barrier], total_loss)

  if variables_to_train is None:
    # Default to tf.trainable_variables().
    variables_to_train = tf.compat.v1.trainable_variables()
  else:
    # Make sure that variables_to_train are in tf.trainable_variables().
    for v in variables_to_train:
      assert v in tf.compat.v1.trainable_variables()

  assert variables_to_train

  # Create the gradients. Note that apply_gradients adds the gradient
  # computation to the current graph.
  grads = optimizer.compute_gradients(
      total_loss,
      variables_to_train,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)

  if transform_grads_fn:
    grads = transform_grads_fn(grads)

  # Summarize gradients.
  if summarize_gradients:
    with tf.compat.v1.name_scope('summarize_grads'):
      add_gradients_summaries(grads)

  # Create gradient updates.
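  # (`apply_gradients` both applies the gradients and, when `global_step` is
  # not None, increments it by one.)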
  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

  with tf.compat.v1.name_scope('train_op'):
    # Make sure total_loss is valid.
    if check_numerics:
      total_loss = tf.debugging.check_numerics(total_loss,
                                               'LossTensor is inf or nan')

    # Ensure the train_tensor computes grad_updates.
    train_op = _with_dependencies([grad_updates], total_loss)

  # Add the operation used for training to the 'train_op' collection.
  train_ops = tf.compat.v1.get_collection_ref(tf.compat.v1.GraphKeys.TRAIN_OP)
  if train_op not in train_ops:
    train_ops.append(train_op)

  return train_op
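
# Example usage (a minimal sketch, not part of contrib_utils.py; `my_model`,
# `images`, and `labels` are illustrative placeholders):
#
#   predictions = my_model(images)
#   loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)
#   optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
#   train_op = create_train_op(loss, optimizer, summarize_gradients=True)
#   with tf.compat.v1.train.MonitoredTrainingSession() as sess:
#     while not sess.should_stop():
#       # Running `train_op` executes the update ops and the gradient step,
#       # then returns the value of the loss.
#       loss_value = sess.run(train_op)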