def __init__()

in tensorflow_benchmark/tf_cnn_benchmarks/benchmark_cnn.py


  def __init__(self, params):
    """Initialize BenchmarkCNN.

    Args:
      params: Params tuple, typically created by make_params or
              make_params_from_flags.
    Raises:
      ValueError: Unsupported params settings.
    """
    self.params = params
    self.dataset = datasets.create_dataset(self.params.data_dir,
                                           self.params.data_name)
    self.model = model_config.get_model_config(self.params.model, self.dataset)
    self.trace_filename = self.params.trace_file
    self.data_format = self.params.data_format
    self.num_batches = self.params.num_batches
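    # TF's convolution autotuning (TF_AUTOTUNE_THRESHOLD) re-benchmarks
    # kernels during the first iterations, so the default warm-up length
    # grows quadratically with the threshold to keep autotuning out of the
    # timed batches.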
    autotune_threshold = self.params.autotune_threshold if (
        self.params.autotune_threshold) else 1
    min_autotune_warmup = 5 * autotune_threshold * autotune_threshold
    self.num_warmup_batches = self.params.num_warmup_batches if (
        self.params.num_warmup_batches is not None) else max(
            10, min_autotune_warmup)
    self.graph_file = self.params.graph_file
    self.resize_method = self.params.resize_method
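    # Counter used to generate a unique name for each synchronization
    # barrier queue.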
    self.sync_queue_counter = 0
    self.num_gpus = self.params.num_gpus
    if self.params.gpu_indices:
      self.gpu_indices = [int(x) for x in self.params.gpu_indices.split(',')]
    else:
      self.gpu_indices = list(range(self.num_gpus))
    self.use_synthetic_gpu_images = self.dataset.use_synthetic_gpu_images()

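    # Reject unsupported combinations of params up front.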
    if (self.params.device == 'cpu' and self.params.data_format == 'NCHW' and
        not self.params.mkl):
      raise ValueError('device=cpu requires that data_format=NHWC unless '
                       'MKL is enabled')

    if self.params.use_tf_layers and self.params.use_fp16:
      raise ValueError('if use_fp16=true, use_tf_layers must be false.')

    if ((self.params.num_epochs_per_decay or
         self.params.learning_rate_decay_factor) and
        not (self.params.learning_rate and self.params.num_epochs_per_decay and
             self.params.learning_rate_decay_factor)):
      raise ValueError('If one of num_epochs_per_decay or '
                       'learning_rate_decay_factor is set, both must be set, '
                       'and learning_rate must be set')
    if (self.params.minimum_learning_rate and
        not (self.params.learning_rate and self.params.num_epochs_per_decay and
             self.params.learning_rate_decay_factor)):
      raise ValueError('minimum_learning_rate requires learning_rate, '
                       'num_epochs_per_decay, and '
                       'learning_rate_decay_factor to be set')

    if (self.params.use_fp16 and self.params.fp16_vars and
        'replicated' in self.params.variable_update and
        'nccl' in self.params.all_reduce_spec):
      raise ValueError('fp16 variables are not supported with NCCL')

    # Use the batch size from the command line if specified, otherwise use the
    # model's default batch size.  Scale the benchmark's batch size by the
    # number of GPUs.
    if self.params.batch_size > 0:
      self.model.set_batch_size(self.params.batch_size)
    self.batch_size = self.model.get_batch_size() * self.num_gpus
    self.batch_group_size = self.params.batch_group_size

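    # Loss scaling multiplies the loss so small gradient values remain
    # representable in fp16; the gradients are scaled back down before
    # being applied.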
    if self.params.use_fp16:
      self.loss_scale = (self.params.fp16_loss_scale or
                         self.model.get_fp16_loss_scale())
    else:
      self.loss_scale = 1.

    self.job_name = self.params.job_name  # "" for local training
    self.ps_hosts = self.params.ps_hosts.split(',')
    self.worker_hosts = self.params.worker_hosts.split(',')
    self.controller_host = self.params.controller_host

    if len(self.worker_hosts) > 1 and self.params.all_reduce_spec == 'nccl':
      raise ValueError('--all_reduce_spec=nccl is invalid in a '
                       'multi-worker job')

    # PS server is used for distributed jobs not using all-reduce.
    use_ps_server = self.job_name and (
        self.params.variable_update != 'distributed_all_reduce')
    # A controller is used for distributed_all_reduce with > 1 worker.
    use_controller = (self.params.variable_update == 'distributed_all_reduce'
                      and self.job_name)
    if use_controller and not self.controller_host:
      raise ValueError('When variable_update==distributed_all_reduce, '
                       'controller_host must also be specified.')

    self.local_parameter_device_flag = self.params.local_parameter_device
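    # Distributed mode: build the ClusterSpec and, except for the controller
    # (which only coordinates), start the tf.train.Server that serves this
    # task.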
    if self.job_name:
      self.task_index = self.params.task_index
      if use_controller:
        assert not use_ps_server
        self.cluster = tf.train.ClusterSpec(
            {'controller': [self.controller_host],
             'worker': self.worker_hosts})
      else:
        assert use_ps_server
        self.cluster = tf.train.ClusterSpec({'ps': self.ps_hosts,
                                             'worker': self.worker_hosts})

      self.server = None
      if self.job_name != 'controller':
        self.server = tf.train.Server(self.cluster, job_name=self.job_name,
                                      task_index=self.task_index,
                                      config=create_config_proto(self.params),
                                      protocol=self.params.server_protocol)

      worker_prefix = '/job:worker/task:%s' % self.task_index
      if use_ps_server:
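        # replica_device_setter keeps compute ops on this worker while
        # spreading variables across the ps tasks (round-robin by default).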
        self.param_server_device = tf.train.replica_device_setter(
            worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
        # The device on which the queues for managing synchronization
        # between servers should be stored.
        num_ps = len(self.ps_hosts)
        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i
                                   for i in range(num_ps)]
      else:
        self.sync_queue_devices = ['/job:worker/task:0/cpu:0']
    else:
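      # Single-machine mode: no cluster or server; variables and
      # synchronization queues live on the local parameter device.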
      self.task_index = 0
      self.cluster = None
      self.server = None
      worker_prefix = ''
      self.param_server_device = '/%s:0' % self.params.local_parameter_device
      self.sync_queue_devices = [self.param_server_device]

    # Device to use for ops that need to always run on the local worker's CPU.
    self.cpu_device = '%s/cpu:0' % worker_prefix

    # Devices to use for ops that need to always run on the local worker's
    # compute devices, and never on a parameter server device.
    self.raw_devices = [
        '%s/%s:%i' % (worker_prefix, self.params.device, i)
        for i in range(self.num_gpus)
    ]

    if (self.params.staged_vars and
        self.params.variable_update != 'parameter_server'):
      raise ValueError('staged_vars is currently only supported with '
                       'variable_update=parameter_server')

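    # Pick the VariableMgr implementing the requested variable_update
    # strategy; it decides where variables are placed and how gradients are
    # aggregated across devices (and, when distributed, across workers).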
    if self.params.variable_update == 'parameter_server':
      if self.job_name:
        if not self.params.staged_vars:
          self.variable_mgr = variable_mgr.VariableMgrDistributedFetchFromPS(
              self)
        else:
          self.variable_mgr = (
              variable_mgr.VariableMgrDistributedFetchFromStagedPS(self))
      else:
        if not self.params.staged_vars:
          self.variable_mgr = variable_mgr.VariableMgrLocalFetchFromPS(self)
        else:
          self.variable_mgr = variable_mgr.VariableMgrLocalFetchFromStagedPS(
              self)
    elif self.params.variable_update == 'replicated':
      if self.job_name:
        raise ValueError('Invalid variable_update in distributed mode: %s' %
                         self.params.variable_update)
      self.variable_mgr = variable_mgr.VariableMgrLocalReplicated(
          self, self.params.all_reduce_spec)
    elif self.params.variable_update == 'distributed_all_reduce':
      assert self.params.cross_replica_sync
      self.variable_mgr = variable_mgr.VariableMgrDistributedAllReduce(
          self, self.params.all_reduce_spec,
          'worker' if len(self.worker_hosts) > 1 else 'localhost',
          len(self.worker_hosts))
    elif self.params.variable_update == 'distributed_replicated':
      assert self.params.cross_replica_sync
      if not self.job_name:
        raise ValueError('Invalid variable_update in local mode: %s' %
                         self.params.variable_update)
      self.variable_mgr = variable_mgr.VariableMgrDistributedReplicated(self)
    elif self.params.variable_update == 'independent':
      if self.job_name:
        raise ValueError('Invalid variable_update in distributed mode: %s' %
                         self.params.variable_update)
      self.variable_mgr = variable_mgr.VariableMgrIndependent(self)
    else:
      raise ValueError(
          'Invalid variable_update: %s' % self.params.variable_update)

    # Devices to use for running the model; ops run on the local worker's
    # compute devices, while variables may be assigned to parameter server
    # devices.
    self.devices = self.variable_mgr.get_devices()
    if self.job_name:
      if use_ps_server:
        self.global_step_device = self.param_server_device
      else:
        self.global_step_device = '/job:worker/task:0/cpu:0'
    else:
      self.global_step_device = self.cpu_device

    self.image_preprocessor = self.get_image_preprocessor()
    self.init_global_step = 0
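
For reference, a minimal construction sketch. The flag values below are
illustrative rather than recommendations, and the import assumes the code
runs from the tf_cnn_benchmarks directory:

  import benchmark_cnn

  params = benchmark_cnn.make_params(num_gpus=1,
                                     batch_size=32,
                                     model='resnet50',
                                     variable_update='parameter_server')
  bench = benchmark_cnn.BenchmarkCNN(params)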