def _benchmark_cnn()

in tensorflow_benchmark/tf_cnn_benchmarks/benchmark_cnn.py [0:0]


  def _benchmark_cnn(self):
    """Run cnn in benchmark mode. When forward_only on, it forwards CNN.

    Returns:
      Dictionary containing training statistics (num_workers, num_steps,
      average_wall_time, images_per_sec).
    """
    if self.params.variable_update == 'distributed_all_reduce':
      self.single_session = True
      (image_producer_ops, enqueue_ops, fetches) = (
          self._build_model_single_session())
    else:
      self.single_session = False
      (image_producer_ops, enqueue_ops, fetches) = self._build_model()
    fetches_list = nest.flatten(list(fetches.values()))
    main_fetch_group = tf.group(*fetches_list)
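    # When replicas are not synchronized at every step, an execution barrier is
    # still used so that workers finish warm up together and do not shut down
    # while other replicas are still running.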
    execution_barrier = None
    if (not self.single_session and self.job_name and
        not self.params.cross_replica_sync):
      execution_barrier = self.add_sync_queues_and_barrier(
          'execution_barrier_', [])

    global_step = tf.train.get_global_step()
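    # The global step is advanced only after the main fetches for this step
    # have completed, via the control dependency below.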
    with tf.device(self.global_step_device):
      with tf.control_dependencies([main_fetch_group]):
        fetches['inc_global_step'] = global_step.assign_add(1)

    if ((not self.single_session) and self.job_name and
        self.params.cross_replica_sync):
      # Block all replicas until every replica is ready for the next step.
      fetches['sync_queues'] = self.add_sync_queues_and_barrier(
          'sync_queues_step_end_', [main_fetch_group])

    local_var_init_op = tf.local_variables_initializer()
    variable_mgr_init_ops = [local_var_init_op]
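    # The variable manager's post-init ops (e.g. copying initial variable
    # values between devices or replicas) must run only after local variables
    # have been initialized, hence the control dependency below.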
    with tf.control_dependencies([local_var_init_op]):
      variable_mgr_init_ops.extend(self.variable_mgr.get_post_init_ops())
    local_var_init_op_group = tf.group(*variable_mgr_init_ops)

    summary_op = tf.summary.merge_all()
    is_chief = (not self.job_name or self.task_index == 0)
    summary_writer = None
    if (is_chief and self.params.summary_verbosity and self.params.train_dir and
        self.params.save_summaries_steps > 0):
      summary_writer = tf.summary.FileWriter(self.params.train_dir,
                                             tf.get_default_graph())

    # We want to start the benchmark timer right after an image_producer
    # barrier to avoid undesired waiting times on barriers.
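    # E.g. with 10 warm up batches, 3 enqueue ops and batch_group_size 8,
    # (10 + 3 - 1) = 12 is not a multiple of 8, so the warm up count is
    # raised to 14, making (14 + 3 - 1) = 16 a multiple of batch_group_size.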
    if ((self.num_warmup_batches + len(enqueue_ops) - 1) %
        self.batch_group_size) != 0:
      self.num_warmup_batches = int(
          math.ceil((self.num_warmup_batches + len(enqueue_ops) - 1.0) /
                    self.batch_group_size) * self.batch_group_size
          - len(enqueue_ops) + 1)
      log_fn('Rounding up warm up steps to %d to match batch_group_size' %
             self.num_warmup_batches)
      assert ((self.num_warmup_batches + len(enqueue_ops) - 1) %
              self.batch_group_size) == 0
    # We run the summaries in the same thread as the training operations by
    # passing in None for summary_op to avoid a summary_thread being started.
    # Running summaries and training operations in parallel could run out of
    # GPU memory.
    saver = tf.train.Saver(self.variable_mgr.savable_variables())
    ready_for_local_init_op = None
    if self.job_name and not self.single_session:
      # In distributed mode, we don't want to run local_var_init_op_group until
      # the global variables are initialized, because local_var_init_op_group
      # may use global variables (such as in distributed replicated mode). We
      # don't set this in non-distributed mode, because in non-distributed mode,
      # local_var_init_op_group may itself initialize global variables (such as
      # in replicated mode).
      ready_for_local_init_op = tf.report_uninitialized_variables(
          tf.global_variables())
    sv = tf.train.Supervisor(
        is_chief=is_chief,
        logdir=self.params.train_dir,
        ready_for_local_init_op=ready_for_local_init_op,
        local_init_op=local_var_init_op_group,
        saver=saver,
        global_step=global_step,
        summary_op=None,
        save_model_secs=self.params.save_model_secs,
        summary_writer=summary_writer)

    step_train_times = []
    start_standard_services = (self.params.summary_verbosity >= 1 or
                               self.dataset.queue_runner_required())
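    # The controller process connects to the first worker host; other
    # processes use their own server target, or an empty target when running
    # without a server.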
    if self.job_name == 'controller':
      master_target = self.worker_hosts[0]
    else:
      master_target = self.server.target if self.server else ''
    with sv.managed_session(
        master=master_target,
        config=create_config_proto(self.params),
        start_standard_services=start_standard_services) as sess:
      image_producer = cnn_util.ImageProducer(sess, image_producer_ops,
                                              self.batch_group_size)
      image_producer.start()
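      # Prime the staged input pipeline by running successively longer
      # prefixes of the enqueue ops, so every staging level holds data before
      # the benchmark loop starts.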
      for i in xrange(len(enqueue_ops)):
        sess.run(enqueue_ops[:(i+1)])
        image_producer.notify_image_consumption()
      self.init_global_step, = sess.run([global_step])
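      # Without a single session, timing is based on the shared global step:
      # the watcher records when the cluster as a whole starts and finishes
      # the post-warm-up steps.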
      if not self.single_session:
        global_step_watcher = GlobalStepWatcher(
            sess, global_step,
            len(self.worker_hosts) * self.num_warmup_batches +
            self.init_global_step,
            len(self.worker_hosts) * (
                self.num_warmup_batches + self.num_batches) - 1)
        global_step_watcher.start()
      else:
        global_step_watcher = None

      if self.graph_file is not None:
        path, filename = os.path.split(self.graph_file)
        as_text = filename.endswith('txt')
        log_fn('Writing GraphDef as %s to %s' % (
            'text' if as_text else 'binary', self.graph_file))
        tf.train.write_graph(sess.graph_def, path, filename, as_text)

      log_fn('Running warm up')
      local_step = -1 * self.num_warmup_batches
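      # local_step counts up from -num_warmup_batches and reaches 0 once warm
      # up is complete.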

      if self.single_session or (self.params.cross_replica_sync and
                                 self.params.job_name):
        # In cross-replica sync mode, all workers must run the same number of
        # local steps, or else the workers running the extra step will block.
        done_fn = lambda: local_step == self.num_batches
      else:
        done_fn = global_step_watcher.done
      loop_start_time = time.time()
      while not done_fn():
        if local_step == 0:
          log_fn('Done warm up')
          if execution_barrier:
            log_fn('Waiting for other replicas to finish warm up')
            assert global_step_watcher.start_time == 0
            sess.run([execution_barrier])

          header_str = 'Step\tImg/sec\tloss'
          if self.params.print_training_accuracy or self.params.forward_only:
            header_str += '\ttop_1_accuracy\ttop_5_accuracy'
          log_fn(header_str)
          assert len(step_train_times) == self.num_warmup_batches
          # Reset times to ignore the warm up batches.
          step_train_times = []
          loop_start_time = time.time()
        if (summary_writer and
            (local_step + 1) % self.params.save_summaries_steps == 0):
          fetch_summary = summary_op
        else:
          fetch_summary = None
        summary_str = benchmark_one_step(
            sess, fetches, local_step,
            self.batch_size * (len(self.worker_hosts)
                               if self.single_session else 1),
            step_train_times, self.trace_filename, image_producer, self.params,
            fetch_summary)
        if summary_str is not None and is_chief:
          sv.summary_computed(sess, summary_str)
        local_step += 1
      loop_end_time = time.time()
      # Waits for the global step to be done, regardless of done_fn.
      if global_step_watcher:
        while not global_step_watcher.done():
          time.sleep(.25)
      if self.single_session:
        num_workers = len(self.worker_hosts)
        num_steps = local_step
        elapsed_time = loop_end_time - loop_start_time
      else:
        num_workers = 1
        num_steps = global_step_watcher.num_steps()
        elapsed_time = global_step_watcher.elapsed_time()

      average_wall_time = elapsed_time / num_steps if num_steps > 0 else 0
      images_per_sec = ((num_workers * self.batch_size) /
                        average_wall_time if average_wall_time > 0 else 0)

      log_fn('-' * 64)
      log_fn('total images/sec: %.2f' % images_per_sec)
      log_fn('-' * 64)
      image_producer.done()
      if is_chief:
        store_benchmarks({'total_images_per_sec': images_per_sec}, self.params)
      # Save the model checkpoint.
      if self.params.train_dir is not None and is_chief:
        checkpoint_path = os.path.join(self.params.train_dir, 'model.ckpt')
        if not gfile.Exists(self.params.train_dir):
          gfile.MakeDirs(self.params.train_dir)
        sv.saver.save(sess, checkpoint_path, global_step)

      if execution_barrier:
        # Wait for other workers to reach the end, so this worker doesn't
        # go away underneath them.
        sess.run([execution_barrier])
    sv.stop()
    return {
        'num_workers': num_workers,
        'num_steps': num_steps,
        'average_wall_time': average_wall_time,
        'images_per_sec': images_per_sec
    }
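
A minimal usage sketch, assuming an already constructed BenchmarkCNN instance
named bench (the name and the direct call are illustrative, not part of
benchmark_cnn.py); it only shows how the returned statistics fit together:

  stats = bench._benchmark_cnn()
  # images_per_sec is derived inside the method as
  # num_workers * batch_size / average_wall_time (0 if no steps ran).
  print('workers: %d  steps: %d' % (stats['num_workers'], stats['num_steps']))
  print('sec/step: %.4f  images/sec: %.2f' %
        (stats['average_wall_time'], stats['images_per_sec']))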