in tensorflow_gan/python/train.py [0:0]
def train_step(sess, train_op, global_step, train_step_kwargs):
  """Runs a single training step and reports whether training should stop.

  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments. The keys read here
      are 'should_trace', 'logdir', 'summary_writer', 'should_log' and
      'should_stop'. Each is optional, except that 'logdir' must be present
      whenever 'should_trace' is.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.

  Raises:
    ValueError: If 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  step_began = time.time()

  # Validate the tracing configuration before running anything in the session.
  tracing_configured = 'should_trace' in train_step_kwargs
  if tracing_configured and 'logdir' not in train_step_kwargs:
    raise ValueError('logdir must be present in train_step_kwargs when '
                     'should_trace is present')

  # Full tracing is enabled only for steps where `should_trace` evaluates to a
  # truthy value; otherwise both stay None and the step runs untraced.
  run_options = None
  metadata = None
  if tracing_configured and sess.run(train_step_kwargs['should_trace']):
    run_options = tf.compat.v1.RunOptions(
        trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
    metadata = tf.compat.v1.RunMetadata()

  # Fetch the loss and the global step together in one session call.
  total_loss, step_value = sess.run(
      [train_op, global_step], options=run_options, run_metadata=metadata)
  elapsed_secs = time.time() - step_began

  if metadata is not None:
    trace_path = os.path.join(train_step_kwargs['logdir'],
                              'tf_trace-%d.json' % step_value)
    # NOTE(review): only the intended path is logged; nothing in this function
    # writes a file to `trace_path`. Confirm whether that is intentional.
    tf.compat.v1.logging.info('Writing trace to %s', trace_path)
    if 'summary_writer' in train_step_kwargs:
      metadata_tag = 'run_metadata-%d' % step_value
      train_step_kwargs['summary_writer'].add_run_metadata(
          metadata, metadata_tag)

  if 'should_log' in train_step_kwargs and sess.run(
      train_step_kwargs['should_log']):
    tf.compat.v1.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                              step_value, total_loss, elapsed_secs)

  # TODO(joelshor): Figure out why the stop check can't be folded into the
  # main sess.run call above. The stop check depends on the global step, and
  # the global step is often incremented by the train op itself (created via
  # optimizer.apply_gradients).
  #
  # Since running `train_op` increments the global step, one would expect a
  # control dependency to allow the check in the same session.run call:
  #
  #   with ops.control_dependencies([train_op]):
  #     should_stop_op = ...
  #
  # However, this actually seems not to work on certain platforms, so the
  # check is evaluated in its own session call.
  should_stop = (
      sess.run(train_step_kwargs['should_stop'])
      if 'should_stop' in train_step_kwargs else False)
  return total_loss, should_stop