def get_tpu_cluster_resolver_fn()

in src/python/tensorflow_cloud/core/preprocess.py [0:0]


def get_tpu_cluster_resolver_fn():
    """Builds the source lines of a helper that waits for a Cloud TPU.

    The returned lines are meant to be written into the user's script that
    runs inside the custom container. They define
    `wait_for_tpu_cluster_resolver_ready()`, which polls until the TPU has
    been provisioned so that `TPUClusterResolver` is only handed back once
    it reports a 'worker' job.

    The generated code imports `json`, `logging` and `time` itself, but it
    refers to `os` and `tf` without importing them, so both names must
    already be bound in the script the lines are inserted into.

    Behavior of the generated function:
      * returns None when the `TPU_CONFIG` environment variable is unset
        or empty;
      * otherwise makes up to 40 attempts, sleeping 10 seconds after each
        unsuccessful one, and returns the resolver as soon as its cluster
        spec contains a 'worker' entry;
      * raises RuntimeError if the last attempt raises, or if no attempt
        ever finds a worker.

    See:
    https://cloud.google.com/ai-platform/training/docs/using-tpus#custom-containers

    Returns:
        A new list of strings, one per generated source line, each ending
        in a single newline character.
    """
    # Every entry below is one generated source line WITHOUT its trailing
    # newline; the terminator is appended uniformly in the return statement.
    preamble = (
        "import json",
        "import logging",
        "import time",
        "logger = logging.getLogger(__name__)",
        "logging.basicConfig(level=logging.INFO)",
    )

    # Function header and the TPU_CONFIG lookup / early exit.
    config_lookup = (
        "def wait_for_tpu_cluster_resolver_ready():",
        "  tpu_config_env = os.environ.get('TPU_CONFIG')",
        "  if not tpu_config_env:",
        "    logging.info('Missing TPU_CONFIG, use CPU/GPU for training.')",
        "    return None",
        "  tpu_node = json.loads(tpu_config_env)",
        "  logging.info('Waiting for TPU to be ready: %s.', tpu_node)",
    )

    # Polling loop. Parenthesized entries are a single generated line that
    # is only split here to respect the line-length limit of this file.
    retry_loop = (
        "  num_retries = 40",
        "  for i in range(num_retries):",
        "    try:",
        "      tpu_cluster_resolver = (",
        "          tf.distribute.cluster_resolver.TPUClusterResolver(",
        "              tpu=[tpu_node['tpu_node_name']],",
        "              zone=tpu_node['zone'],",
        "              project=tpu_node['project'],",
        "              job_name='worker'))",
        ("      tpu_cluster_resolver_dict = "
         "tpu_cluster_resolver.cluster_spec().as_dict()"),
        "      if 'worker' in tpu_cluster_resolver_dict:",
        ("        logging.info('Found TPU worker: %s', "
         "tpu_cluster_resolver_dict)"),
        "        return tpu_cluster_resolver",
        "    except Exception as e:",
        "      if i < num_retries - 1:",
        ("        logging.info('Still waiting for provisioning of TPU VM "
         "instance.')"),
        "      else:",
        "        # Preserves the traceback.",
        "        raise RuntimeError('Failed to schedule TPU: {}'.format(e))",
        "    time.sleep(10)",
        "  raise RuntimeError('Failed to schedule TPU.')",
    )

    return [
        text + "\n" for text in preamble + config_lookup + retry_loop
    ]