def get_distribution_strategy_str()

in src/python/tensorflow_cloud/core/experimental/models.py [0:0]


def get_distribution_strategy_str(run_kwargs):
  """Gets the name of a distribution strategy based on cloud run config.

  Args:
    run_kwargs: Dict of cloud run settings. Only the keys 'worker_count',
      'worker_config' and 'chief_config' are read, and each is optional.

  Returns:
    'tpu' or 'multi_mirror' when a positive worker count is given, otherwise
    'mirror' when the chief config reports more than one accelerator,
    otherwise 'one_device'.
  """
  # A missing or None worker count means "no workers"; comparing None with an
  # int would raise TypeError.
  worker_count = run_kwargs.get('worker_count') or 0
  if worker_count > 0:
    worker_config = run_kwargs.get('worker_config')
    # NOTE(review): string placeholders such as 'auto' are presumably legal
    # here (as in the public run API) and carry no accelerator details, so
    # they are never handed to is_tpu_config -- confirm against callers.
    if ('worker_config' in run_kwargs and
        not isinstance(worker_config, str) and
        machine_config.is_tpu_config(worker_config)):
      return 'tpu'
    return 'multi_mirror'

  chief_config = run_kwargs.get('chief_config')
  # None and string placeholders have no accelerator_count attribute; fall
  # through to the single-device answer instead of raising AttributeError.
  if chief_config is None or isinstance(chief_config, str):
    return 'one_device'
  if chief_config.accelerator_count > 1:
    return 'mirror'
  return 'one_device'