07_training/07d_distribute.ipynb (517 lines of code) (raw):

{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 72 }, "id": "hiQ6zAoYhyaA", "outputId": "0acee878-1207-42c3-9bee-a594acd44365" }, "outputs": [], "source": [ "from IPython.display import Markdown as md\n", "\n", "### change to reflect your notebook\n", "_nb_loc = \"07_training/07d_distribute.ipynb\"\n", "_nb_title = \"Distributed Training\"\n", "\n", "### no need to change any of this\n", "_nb_safeloc = _nb_loc.replace('/', '%2F')\n", "md(\"\"\"\n", "<table class=\"tfo-notebook-buttons\" align=\"left\">\n", " <td>\n", " <a target=\"_blank\" href=\"https://console.cloud.google.com/ai-platform/notebooks/deploy-notebook?name={1}&url=https%3A%2F%2Fgithub.com%2FGoogleCloudPlatform%2Fpractical-ml-vision-book%2Fblob%2Fmaster%2F{2}&download_url=https%3A%2F%2Fgithub.com%2FGoogleCloudPlatform%2Fpractical-ml-vision-book%2Fraw%2Fmaster%2F{2}\">\n", " <img src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/practical-ml-vision-book/master/logo-cloud.png\"/> Run in AI Platform Notebook</a>\n", " </td>\n", " </td>\n", " <td>\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/GoogleCloudPlatform/practical-ml-vision-book/blob/master/{0}\">\n", " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " </td>\n", " <td>\n", " <a target=\"_blank\" href=\"https://github.com/GoogleCloudPlatform/practical-ml-vision-book/blob/master/{0}\">\n", " <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n", " </td>\n", " <td>\n", " <a href=\"https://raw.githubusercontent.com/GoogleCloudPlatform/practical-ml-vision-book/master/{0}\">\n", " <img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n", " </td>\n", "</table>\n", "\"\"\".format(_nb_loc, _nb_title, _nb_safeloc))" ] }, { "cell_type": "markdown", "metadata": { "id": "a8HQYsAtC0Fv" }, "source": [ "# Distributed Training\n", "\n", "In this notebook, we show how to train the model on multiple GPUs." ] }, { "cell_type": "markdown", "metadata": { "id": "5UOm2etrwYCs" }, "source": [ "## Enable GPU and set up helper functions\n", "\n", "This notebook and pretty much every other notebook in this repository\n", "will run faster if you are using a GPU.\n", "On Colab:\n", "- Navigate to Edit→Notebook Settings\n", "- Select GPU from the Hardware Accelerator drop-down\n", "\n", "On Cloud AI Platform Notebooks:\n", "- Navigate to https://console.cloud.google.com/ai-platform/notebooks\n", "- Create an instance with a GPU or select your instance and add a GPU\n", "\n", "Next, we'll confirm that we can connect to the GPU with tensorflow:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ugGJcxKAwhc2", "outputId": "8e946159-46cf-4aba-f53e-622e9ea8adee" }, "outputs": [], "source": [ "import tensorflow as tf\n", "print('TensorFlow version' + tf.version.VERSION)\n", "print('Built with GPU support? ' + ('Yes!' if tf.test.is_built_with_cuda() else 'Noooo!'))\n", "print('There are {} GPUs'.format(len(tf.config.experimental.list_physical_devices(\"GPU\"))))\n", "device_name = tf.test.gpu_device_name()\n", "if device_name != '/device:GPU:0':\n", " raise SystemError('GPU device not found')\n", "print('Found GPU at: {}'.format(device_name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Strategy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pylab as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import os, shutil\n", "from tensorflow.data.experimental import AUTOTUNE\n", "\n", "#PATTERN_SUFFIX, NUM_EPOCHS = '-0000[01]-*', 3 # small\n", "PATTERN_SUFFIX, NUM_EPOCHS = '-*', 20 # full\n", "\n", "IMG_HEIGHT = 448 # note *twice* what we used to have\n", "IMG_WIDTH = 448\n", "IMG_CHANNELS = 3\n", "CLASS_NAMES = 'daisy dandelion roses sunflowers tulips'.split()\n", "\n", "CHECKPOINT_DIR='./chkpts'\n", "OUTDIR = './export'\n", "\n", "def setup_trainer():\n", " shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)\n", " os.mkdir(CHECKPOINT_DIR)\n", " shutil.rmtree(OUTDIR, ignore_errors=True)\n", " os.mkdir(OUTDIR)\n", " # Load compressed models from tensorflow_hub\n", " os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' \n", "\n", "def create_strategy(mode='gpus_one_machine'):\n", " \"\"\"\n", " mode has be to be one of the following:\n", " * gpus_one_machine\n", " * gpus_multiple_machines\n", " * tpu_colab\n", " * tpu_caip\n", " * the actual name of the cloud_tpu\n", " \"\"\"\n", " if mode == 'gpus_one_machine':\n", " print('Using {} GPUs'.format(len(tf.config.experimental.list_physical_devices(\"GPU\"))))\n", " return tf.distribute.MirroredStrategy()\n", " if mode == 'gpus_multiple_machines':\n", " print(\"Using TFCONFIG=\", os.environ[\"TF_CONFIG\"])\n", " return tf.distribute.MultiWorkerMirroredStrategy()\n", " \n", " # treat as tpu\n", " if mode == 'tpu_colab':\n", " tpu_name = 'grpc://' + os.environ['COLAB_TPU_ADDR']\n", " elif mode == 'tpu_caip':\n", " tpu_name = None\n", " else:\n", " tpu_name = mode\n", " print(\"Using TPU: \", tpu_name)\n", " resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)\n", " tf.config.experimental_connect_to_cluster(resolver)\n", " # TPUs wipe out memory, so this has to be at very start of program\n", " tf.tpu.experimental.initialize_tpu_system(resolver)\n", " print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n", " return tf.distribute.TPUStrategy(resolver)\n", "\n", "setup_trainer()\n", "strategy = create_strategy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training code\n", "\n", "Streamlined from previous notebooks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def training_plot(metrics, history):\n", " f, ax = plt.subplots(1, len(metrics), figsize=(5*len(metrics), 5))\n", " for idx, metric in enumerate(metrics):\n", " ax[idx].plot(history.history[metric], ls='dashed')\n", " ax[idx].set_xlabel(\"Epochs\")\n", " ax[idx].set_ylabel(metric)\n", " ax[idx].plot(history.history['val_' + metric]);\n", " ax[idx].legend([metric, 'val_' + metric])\n", " \n", "class _Preprocessor: \n", " def __init__(self):\n", " # nothing to initialize\n", " pass\n", " \n", " def read_from_tfr(self, proto):\n", " feature_description = {\n", " 'image': tf.io.VarLenFeature(tf.float32),\n", " 'shape': tf.io.VarLenFeature(tf.int64),\n", " 'label': tf.io.FixedLenFeature([], tf.string, default_value=''),\n", " 'label_int': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n", " }\n", " rec = tf.io.parse_single_example(\n", " proto, feature_description\n", " )\n", " shape = tf.sparse.to_dense(rec['shape'])\n", " img = tf.reshape(tf.sparse.to_dense(rec['image']), shape)\n", " label_int = rec['label_int']\n", " return img, label_int\n", " \n", " def read_from_jpegfile(self, filename):\n", " # same code as in 05_create_dataset/jpeg_to_tfrecord.py\n", " img = tf.io.read_file(filename)\n", " img = tf.image.decode_jpeg(img, channels=IMG_CHANNELS)\n", " img = tf.image.convert_image_dtype(img, tf.float32)\n", " return img\n", " \n", " def preprocess(self, img):\n", " return tf.image.resize_with_pad(img, IMG_HEIGHT, IMG_WIDTH)\n", "\n", "# most efficient way to read the data\n", "# as determined in 07a_ingest.ipynb\n", "# splits the files into two halves and interleaves datasets\n", "def create_preproc_dataset(pattern):\n", " \"\"\"\n", " Does interleaving, parallel calls, prefetch, batching\n", " Caching is not a good idea on large datasets.\n", " \"\"\"\n", " preproc = _Preprocessor()\n", " files = [filename for filename \n", " in tf.random.shuffle(tf.io.gfile.glob(pattern))]\n", " if len(files) > 1:\n", " print(\"Interleaving the reading of {} files.\".format(len(files)))\n", " def _create_half_ds(x):\n", " if x == 0:\n", " half = files[:(len(files)//2)]\n", " else:\n", " half = files[(len(files)//2):]\n", " return tf.data.TFRecordDataset(half,\n", " compression_type='GZIP')\n", " trainds = tf.data.Dataset.range(2).interleave(\n", " _create_half_ds, num_parallel_calls=AUTOTUNE)\n", " else:\n", " trainds = tf.data.TFRecordDataset(files,\n", " compression_type='GZIP')\n", " def _preproc_img_label(img, label):\n", " return (preproc.preprocess(img), label)\n", " \n", " trainds = (trainds\n", " .map(preproc.read_from_tfr, num_parallel_calls=AUTOTUNE)\n", " .map(_preproc_img_label, num_parallel_calls=AUTOTUNE)\n", " .shuffle(200)\n", " .prefetch(AUTOTUNE)\n", " )\n", " return trainds\n", "\n", "def create_preproc_image(filename):\n", " preproc = _Preprocessor()\n", " img = preproc.read_from_jpegfile(filename)\n", " return preproc.preprocess(img)\n", "\n", "class RandomColorDistortion(tf.keras.layers.Layer):\n", " def __init__(self, contrast_range=[0.5, 1.5], \n", " brightness_delta=[-0.2, 0.2], **kwargs):\n", " super(RandomColorDistortion, self).__init__(**kwargs)\n", " self.contrast_range = contrast_range\n", " self.brightness_delta = brightness_delta\n", " \n", " def call(self, images, training=None):\n", " if not training:\n", " return images\n", " \n", " contrast = np.random.uniform(\n", " self.contrast_range[0], self.contrast_range[1])\n", " brightness = np.random.uniform(\n", " self.brightness_delta[0], self.brightness_delta[1])\n", " \n", " images = tf.image.adjust_contrast(images, contrast)\n", " images = tf.image.adjust_brightness(images, brightness)\n", " images = tf.clip_by_value(images, 0, 1)\n", " return images" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_model(l1, l2, num_hidden):\n", " regularizer = tf.keras.regularizers.l1_l2(l1, l2)\n", " \n", " layers = [\n", " tf.keras.layers.experimental.preprocessing.RandomCrop(\n", " height=IMG_HEIGHT//2, width=IMG_WIDTH//2,\n", " input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),\n", " name='random/center_crop'\n", " ),\n", " tf.keras.layers.experimental.preprocessing.RandomFlip(\n", " mode='horizontal',\n", " name='random_lr_flip/none'\n", " ),\n", " RandomColorDistortion(name='random_contrast_brightness/none'),\n", " hub.KerasLayer(\n", " \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4\", \n", " trainable=False,\n", " name='mobilenet_embedding'),\n", " tf.keras.layers.Dense(num_hidden,\n", " kernel_regularizer=regularizer, \n", " activation=tf.keras.activations.relu,\n", " name='dense_hidden'),\n", " tf.keras.layers.Dense(len(CLASS_NAMES), \n", " kernel_regularizer=regularizer,\n", " activation='softmax',\n", " name='flower_prob')\n", " ]\n", "\n", " # checkpoint and early stopping callbacks\n", " model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n", " filepath='./chkpts',\n", " monitor='val_accuracy', mode='max',\n", " save_best_only=True)\n", " early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n", " monitor='val_accuracy', mode='max',\n", " patience=2)\n", " \n", " # create model\n", " return tf.keras.Sequential(layers, name='flower_classification')\n", "\n", "def train_and_evaluate(strategy,\n", " batch_size = 32,\n", " lrate = 0.001,\n", " l1 = 0.,\n", " l2 = 0.,\n", " num_hidden = 16):\n", " \n", " train_dataset = create_preproc_dataset(\n", " 'gs://practical-ml-vision-book-data/flowers_tfr/train' + PATTERN_SUFFIX\n", " ).batch(batch_size)\n", " eval_dataset = create_preproc_dataset(\n", " 'gs://practical-ml-vision-book-data/flowers_tfr/valid' + PATTERN_SUFFIX\n", " ).batch(batch_size)\n", "\n", " # checkpoint and early stopping callbacks\n", " model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n", " filepath='./chkpts',\n", " monitor='val_accuracy', mode='max',\n", " save_best_only=True)\n", " early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n", " monitor='val_accuracy', mode='max',\n", " patience=2)\n", " \n", " # model training\n", " with strategy.scope():\n", " model = create_model(l1, l2, num_hidden)\n", " model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lrate),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(\n", " from_logits=False),\n", " metrics=['accuracy']\n", " )\n", " print(model.summary())\n", " history = model.fit(train_dataset, \n", " validation_data=eval_dataset,\n", " epochs=NUM_EPOCHS,\n", " callbacks=[model_checkpoint_cb, early_stopping_cb]\n", " )\n", " training_plot(['loss', 'accuracy'], history)\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 929 }, "id": "jlxxxpeaT6ea", "outputId": "ad4f09e8-bc33-4c92-dc47-5fc73ec12f9c" }, "outputs": [], "source": [ "model = train_and_evaluate(strategy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@tf.function(input_signature=[tf.TensorSpec([None,], dtype=tf.string)])\n", "def predict_flower_type(filenames):\n", " input_images = tf.map_fn(\n", " create_preproc_image,\n", " filenames,\n", " fn_output_signature=tf.float32\n", " )\n", " batch_pred = model(input_images) # same as model.predict()\n", " top_prob = tf.math.reduce_max(batch_pred, axis=[1])\n", " pred_label_index = tf.math.argmax(batch_pred, axis=1)\n", " pred_label = tf.gather(tf.convert_to_tensor(CLASS_NAMES), pred_label_index)\n", " return {\n", " 'probability': top_prob,\n", " 'flower_type_int': pred_label_index,\n", " 'flower_type_str': pred_label\n", " }\n", "\n", "model.save(os.path.join(OUTDIR, 'flowers_model'),\n", " signatures={\n", " 'serving_default': predict_flower_type\n", " })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Try it out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "serving_fn = tf.keras.models.load_model(\n", " os.path.join(OUTDIR, 'flowers_model')\n", ").signatures['serving_default']\n", "filenames = [\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/dandelion/9818247_e2eac18894.jpg',\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/dandelion/9853885425_4a82356f1d_m.jpg',\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/daisy/9158041313_7a6a102f7a_n.jpg',\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/daisy/9299302012_958c70564c_n.jpg',\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/tulips/8733586143_3139db6e9e_n.jpg',\n", " 'gs://practical-ml-vision-book-data/flowers_5_jpeg/flower_photos/tulips/8713397358_0505cc0176_n.jpg'\n", "]\n", "pred = serving_fn(tf.convert_to_tensor(filenames))\n", "print(pred)\n", "print('******')\n", "print(pred['flower_type_str'].numpy())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f, ax = plt.subplots(1, 6, figsize=(15,15))\n", "for idx, (filename, prob, pred_label) in enumerate(\n", " zip(filenames, pred['probability'].numpy(), pred['flower_type_str'].numpy())):\n", " \n", " img = tf.io.read_file(filename)\n", " img = tf.image.decode_jpeg(img, channels=3)\n", " ax[idx].imshow((img.numpy()));\n", " \n", " ax[idx].set_title('{} ({:.2f})'.format(pred_label, prob)) " ] }, { "cell_type": "markdown", "metadata": { "id": "Duu8mX3iXANE" }, "source": [ "## License\n", "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "5UOm2etrwYCs" ], "name": "03a_transfer_learning.ipynb", "provenance": [], "toc_visible": true }, "environment": { "name": "tf2-2-3-gpu.2-3.m59", "type": "gcloud", "uri": "gcr.io/deeplearning-platform-release/tf2-2-3-gpu.2-3:m59" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.8" } }, "nbformat": 4, "nbformat_minor": 4 }