def get_distribution_strategy()

in example_zoo/tensorflow/models/mnist/official/utils/misc/distribution_utils.py


import tensorflow as tf  # required for the tf.distribute / tf.contrib.distribute calls below


def get_distribution_strategy(num_gpus,
                              all_reduce_alg=None,
                              turn_off_distribution_strategy=False):
  """Return a DistributionStrategy for running the model.

  Args:
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Specify which algorithm to use when performing all-reduce.
      See tf.contrib.distribute.AllReduceCrossDeviceOps for available
      algorithms. If None, DistributionStrategy will choose based on device
      topology.
    turn_off_distribution_strategy: When set to True, do not use any
      distribution strategy. Note that when this is True and num_gpus is
      larger than 1, a ValueError is raised.

  Returns:
    tf.contrib.distribute.DistributionStrategy object, or None if
    turn_off_distribution_strategy is True and num_gpus is at most 1.
  Raises:
    ValueError: if turn_off_distribution_strategy is True and num_gpus is
      larger than 1.
  """
  if num_gpus == 0:
    if turn_off_distribution_strategy:
      return None
    else:
      return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
  elif num_gpus == 1:
    if turn_off_distribution_strategy:
      return None
    else:
      return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
  elif turn_off_distribution_strategy:
    raise ValueError("When {} GPUs are specified, "
                     "turn_off_distribution_strategy flag cannot be set to "
                     "True.".format(num_gpus))
  else:  # num_gpus > 1 and not turn_off_distribution_strategy
    devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    if all_reduce_alg:
      return tf.distribute.MirroredStrategy(
          devices=devices,
          cross_device_ops=tf.contrib.distribute.AllReduceCrossDeviceOps(
              all_reduce_alg, num_packs=2))
    else:
      return tf.distribute.MirroredStrategy(devices=devices)
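

A minimal usage sketch, assuming a TF 1.x environment where tf.contrib is
available; the num_gpus/all_reduce_alg values and my_model_fn are
illustrative placeholders, not part of the original file:

# Build a MirroredStrategy over two GPUs with NCCL all-reduce
# (both values are assumptions for illustration).
strategy = get_distribution_strategy(num_gpus=2, all_reduce_alg="nccl")

# The returned strategy is typically passed to an Estimator through
# RunConfig's train_distribute argument.
run_config = tf.estimator.RunConfig(train_distribute=strategy)
# estimator = tf.estimator.Estimator(model_fn=my_model_fn, config=run_config)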