tutorials/model_conversion/basic.ipynb (201 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a brief introduction to using our model converter. Here we use the pretrained MobileNetV2 model as an example." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision\n", "\n", "model = torchvision.models.mobilenet_v2(pretrained=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we will convert it to TFLite using TinyNeuralNetwork." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO (tinynn.converter.base) Generated model saved to mobilenet_v2.tflite\n" ] } ], "source": [ "import sys\n", "sys.path.append('../..')\n", "\n", "from tinynn.converter import TFLiteConverter\n", "\n", "# Provide a viable input to the model\n", "dummy_input = torch.randn(1, 3, 224, 224)\n", "model_path = 'mobilenet_v2.tflite'\n", "\n", "# Move the model to CPU and set it to evaluation mode before model conversion\n", "with torch.no_grad():\n", " model.cpu()\n", " model.eval()\n", "\n", " converter = TFLiteConverter(model, dummy_input, model_path)\n", " converter.convert()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's prepare an example input using an online image. 
" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "import os\n", "from torch.hub import download_url_to_file\n", "from PIL import Image\n", "from torchvision import transforms\n", "import numpy as np\n", "\n", "cwd = os.path.abspath(os.getcwd())\n", "img_path = os.path.join(cwd, 'dog.jpg')\n", "\n", "img_urls = [\n", " 'https://github.com/pytorch/hub/raw/master/images/dog.jpg',\n", " 'https://raw.fastgit.org/pytorch/hub/master/images/dog.jpg',\n", "]\n", "\n", "# If you have difficulties accessing GitHub, then you may try out the second link\n", "download_url_to_file(img_urls[0], img_path)\n", "\n", "img = Image.open(img_path)\n", "\n", "mean = np.array([0.485, 0.456, 0.406], dtype='float32')\n", "std = np.array([0.229, 0.224, 0.225], dtype='float32')\n", "\n", "preprocess = transforms.Compose(\n", " [\n", " transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " ]\n", ")\n", "\n", "# Image preprocessing\n", "processed_img = preprocess(img)\n", "arr = np.asarray(processed_img).astype('float32') / 255\n", "normalized = (arr - mean) / std\n", "input_arr = np.expand_dims(normalized, 0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's run the generated TFLite model with the example input." 
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TFLite out: 258\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "interpreter = tf.lite.Interpreter(model_path=model_path)\n", "interpreter.allocate_tensors()\n", "\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()\n", "\n", "interpreter.set_tensor(input_details[0]['index'], input_arr)\n", "interpreter.invoke()\n", "output_arr = interpreter.get_tensor(output_details[0]['index'])\n", "\n", "print('TFLite out:', np.argmax(output_arr))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check whether the output is consistent with the one predicted by the original model." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PyTorch out: tensor(258)\n" ] } ], "source": [ "torch_input = torch.from_numpy(input_arr).permute((0, 3, 1, 2))\n", "\n", "with torch.no_grad():\n", " torch_output = model(torch_input)\n", "\n", "print('PyTorch out:', torch.argmax(torch_output))" ] } ], "metadata": { "interpreter": { "hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608" }, "kernelspec": { "display_name": "Python 3.8.6 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }