in tensorflow_benchmark/tf_cnn_benchmarks/benchmark_cnn.py
def _benchmark_cnn(self):
"""Run cnn in benchmark mode. When forward_only on, it forwards CNN.
Returns:
Dictionary containing training statistics (num_workers, num_steps,
average_wall_time, images_per_sec).
"""
if self.params.variable_update == 'distributed_all_reduce':
self.single_session = True
(image_producer_ops, enqueue_ops, fetches) = (
self._build_model_single_session())
else:
self.single_session = False
(image_producer_ops, enqueue_ops, fetches) = self._build_model()
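# Flatten and group all fetches so a single op can serve as the control
# dependency for the global-step increment and the end-of-step sync below.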
fetches_list = nest.flatten(list(fetches.values()))
main_fetch_group = tf.group(*fetches_list)
execution_barrier = None
if (not self.single_session and self.job_name and
not self.params.cross_replica_sync):
execution_barrier = self.add_sync_queues_and_barrier(
'execution_barrier_', [])
global_step = tf.train.get_global_step()
with tf.device(self.global_step_device):
with tf.control_dependencies([main_fetch_group]):
fetches['inc_global_step'] = global_step.assign_add(1)
if ((not self.single_session) and self.job_name and
self.params.cross_replica_sync):
# Block all replicas until all replicas are ready for the next step.
fetches['sync_queues'] = self.add_sync_queues_and_barrier(
'sync_queues_step_end_', [main_fetch_group])
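# Initialize local variables first; the variable manager's post-init ops
# (which may copy or broadcast variables between devices, depending on the
# variable_update mode) run only after that, via the control dependency.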
local_var_init_op = tf.local_variables_initializer()
variable_mgr_init_ops = [local_var_init_op]
with tf.control_dependencies([local_var_init_op]):
variable_mgr_init_ops.extend(self.variable_mgr.get_post_init_ops())
local_var_init_op_group = tf.group(*variable_mgr_init_ops)
summary_op = tf.summary.merge_all()
is_chief = (not self.job_name or self.task_index == 0)
summary_writer = None
if (is_chief and self.params.summary_verbosity and self.params.train_dir and
self.params.save_summaries_steps > 0):
summary_writer = tf.summary.FileWriter(self.params.train_dir,
tf.get_default_graph())
# We want to start the benchmark timer right after an image_producer barrier
# to avoid undesired waiting times on barriers.
if ((self.num_warmup_batches + len(enqueue_ops) - 1) %
self.batch_group_size) != 0:
self.num_warmup_batches = int(
math.ceil((self.num_warmup_batches + len(enqueue_ops) - 1.0) /
self.batch_group_size) * self.batch_group_size
- len(enqueue_ops) + 1)
log_fn('Rounding up warm-up steps to %d to match batch_group_size' %
self.num_warmup_batches)
assert ((self.num_warmup_batches + len(enqueue_ops) - 1) %
self.batch_group_size) == 0
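# Example: with num_warmup_batches=8, len(enqueue_ops)=3 and
# batch_group_size=4, (8 + 3 - 1) = 10 is not a multiple of 4, so warm up is
# rounded up to ceil(10 / 4) * 4 - 3 + 1 = 10 and (10 + 3 - 1) = 12 is a
# multiple of 4.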
# We run the summaries in the same thread as the training operations by
# passing in None for summary_op to avoid a summary_thread being started.
# Running summaries and training operations in parallel could run out of
# GPU memory.
saver = tf.train.Saver(self.variable_mgr.savable_variables())
ready_for_local_init_op = None
if self.job_name and not self.single_session:
# In distributed mode, we don't want to run local_var_init_op_group until
# the global variables are initialized, because local_var_init_op_group
# may use global variables (such as in distributed replicated mode). We
# don't set this in non-distributed mode, because in non-distributed mode,
# local_var_init_op_group may itself initialize global variables (such as
# in replicated mode).
ready_for_local_init_op = tf.report_uninitialized_variables(
tf.global_variables())
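# The Supervisor handles session creation/recovery, variable initialization
# and periodic checkpointing; summary_op is None so summaries stay in the
# training thread (see comment above).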
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=self.params.train_dir,
ready_for_local_init_op=ready_for_local_init_op,
local_init_op=local_var_init_op_group,
saver=saver,
global_step=global_step,
summary_op=None,
save_model_secs=self.params.save_model_secs,
summary_writer=summary_writer)
step_train_times = []
start_standard_services = (self.params.summary_verbosity >= 1 or
self.dataset.queue_runner_required())
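# A 'controller' job drives the single distributed session hosted on the
# first worker; other jobs connect to their own in-process server, if any.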
if self.job_name == 'controller':
master_target = self.worker_hosts[0]
else:
master_target = self.server.target if self.server else ''
with sv.managed_session(
master=master_target,
config=create_config_proto(self.params),
start_standard_services=start_standard_services) as sess:
image_producer = cnn_util.ImageProducer(sess, image_producer_ops,
self.batch_group_size)
image_producer.start()
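# Prime the input pipeline: each pass runs one more enqueue op so the staged
# batches are in place before the timed steps begin.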
for i in xrange(len(enqueue_ops)):
sess.run(enqueue_ops[:(i+1)])
image_producer.notify_image_consumption()
self.init_global_step, = sess.run([global_step])
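# In multi-session (distributed) mode, a GlobalStepWatcher thread tracks the
# shared global step so timing covers exactly num_batches steps per worker
# across the whole cluster.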
if not self.single_session:
global_step_watcher = GlobalStepWatcher(
sess, global_step,
len(self.worker_hosts) * self.num_warmup_batches +
self.init_global_step,
len(self.worker_hosts) * (
self.num_warmup_batches + self.num_batches) - 1)
global_step_watcher.start()
else:
global_step_watcher = None
if self.graph_file is not None:
path, filename = os.path.split(self.graph_file)
as_text = filename.endswith('txt')
log_fn('Writing GraphDef as %s to %s' % (
'text' if as_text else 'binary', self.graph_file))
tf.train.write_graph(sess.graph_def, path, filename, as_text)
log_fn('Running warm up')
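# local_step counts up from -num_warmup_batches, so step 0 is the first
# batch timed after warm up.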
local_step = -1 * self.num_warmup_batches
if self.single_session or (self.params.cross_replica_sync and
self.params.job_name):
# In cross-replica sync mode, all workers must run the same number of
# local steps, or else the workers running the extra step will block.
done_fn = lambda: local_step == self.num_batches
else:
done_fn = global_step_watcher.done
loop_start_time = time.time()
while not done_fn():
if local_step == 0:
log_fn('Done warm up')
if execution_barrier:
log_fn('Waiting for other replicas to finish warm up')
assert global_step_watcher.start_time == 0
sess.run([execution_barrier])
header_str = 'Step\tImg/sec\tloss'
if self.params.print_training_accuracy or self.params.forward_only:
header_str += '\ttop_1_accuracy\ttop_5_accuracy'
log_fn(header_str)
assert len(step_train_times) == self.num_warmup_batches
# Reset step times so the warm-up batches are excluded from the statistics.
step_train_times = []
loop_start_time = time.time()
if (summary_writer and
(local_step + 1) % self.params.save_summaries_steps == 0):
fetch_summary = summary_op
else:
fetch_summary = None
summary_str = benchmark_one_step(
sess, fetches, local_step,
self.batch_size * (len(self.worker_hosts)
if self.single_session else 1),
step_train_times, self.trace_filename, image_producer, self.params,
fetch_summary)
if summary_str is not None and is_chief:
sv.summary_computed(sess, summary_str)
local_step += 1
loop_end_time = time.time()
# Wait for the global step to be done, regardless of done_fn.
if global_step_watcher:
while not global_step_watcher.done():
time.sleep(.25)
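# In single-session mode this process ran the steps for every worker, so the
# locally measured loop time is used; otherwise throughput is derived from
# the GlobalStepWatcher's view of the shared global step.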
if self.single_session:
num_workers = len(self.worker_hosts)
num_steps = local_step
elapsed_time = loop_end_time - loop_start_time
else:
num_workers = 1
num_steps = global_step_watcher.num_steps()
elapsed_time = global_step_watcher.elapsed_time()
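# Each step processes num_workers * batch_size images, so e.g. a per-worker
# batch_size of 64 and an average step time of 0.1s give 640 images/sec for
# a single worker.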
average_wall_time = elapsed_time / num_steps if num_steps > 0 else 0
images_per_sec = ((num_workers * self.batch_size) /
average_wall_time if average_wall_time > 0 else 0)
log_fn('-' * 64)
log_fn('total images/sec: %.2f' % images_per_sec)
log_fn('-' * 64)
image_producer.done()
if is_chief:
store_benchmarks({'total_images_per_sec': images_per_sec}, self.params)
# Save the model checkpoint.
if self.params.train_dir is not None and is_chief:
checkpoint_path = os.path.join(self.params.train_dir, 'model.ckpt')
if not gfile.Exists(self.params.train_dir):
gfile.MakeDirs(self.params.train_dir)
sv.saver.save(sess, checkpoint_path, global_step)
if execution_barrier:
# Wait for other workers to reach the end, so this worker doesn't
# go away underneath them.
sess.run([execution_barrier])
sv.stop()
return {
'num_workers': num_workers,
'num_steps': num_steps,
'average_wall_time': average_wall_time,
'images_per_sec': images_per_sec
}