tutorials/model_conversion/basic.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a brief introduction of how to use our model converter. Here we use the pretrained MobileNetV2 model as an example."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"model = torchvision.models.mobilenet_v2(pretrained=True)"
]
},
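{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: on torchvision 0.13 and later, `pretrained=True` is deprecated in favor of the `weights` argument. The following cell is an equivalent way to load the same ImageNet weights on those versions; skip it if the cell above already works for you."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent model loading on torchvision >= 0.13, where `pretrained=True` is deprecated\n",
"# Skip this cell if the previous cell worked for your torchvision version\n",
"model = torchvision.models.mobilenet_v2(\n",
"    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1\n",
")"
]
},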
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we will convert it to TFLite using TinyNerualNetwork."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO (tinynn.converter.base) Generated model saved to mobilenet_v2.tflite\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('../..')\n",
"\n",
"from tinynn.converter import TFLiteConverter\n",
"\n",
"# Provide a viable input to the model\n",
"dummy_input = torch.randn(1, 3, 224, 224)\n",
"model_path = 'mobilenet_v2.tflite'\n",
"\n",
"# Moving the model to cpu and set it to evaluation mode before model conversion\n",
"with torch.no_grad():\n",
" model.cpu()\n",
" model.eval()\n",
"\n",
" converter = TFLiteConverter(model, dummy_input, model_path)\n",
" converter.convert()"
]
},
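{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can confirm that the TFLite file was actually written to disk and take a look at its size. This is just a convenience step on top of the converter's own log message above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Quick sanity check: the converter should have written the model to disk\n",
"assert os.path.exists(model_path), f'{model_path} was not generated'\n",
"print(f'{model_path}: {os.path.getsize(model_path) / 1024 / 1024:.2f} MiB')"
]
},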
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's prepare an example input using an online image. "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from torch.hub import download_url_to_file\n",
"from PIL import Image\n",
"from torchvision import transforms\n",
"import numpy as np\n",
"\n",
"cwd = os.path.abspath(os.getcwd())\n",
"img_path = os.path.join(cwd, 'dog.jpg')\n",
"\n",
"img_urls = [\n",
" 'https://github.com/pytorch/hub/raw/master/images/dog.jpg',\n",
" 'https://raw.fastgit.org/pytorch/hub/master/images/dog.jpg',\n",
"]\n",
"\n",
"# If you have diffculties accessing Github, then you may try out the second link\n",
"download_url_to_file(img_urls[0], img_path)\n",
"\n",
"img = Image.open(img_path)\n",
"\n",
"mean = np.array([0.485, 0.456, 0.406], dtype='float32')\n",
"std = np.array([0.229, 0.224, 0.225], dtype='float32')\n",
"\n",
"preprocess = transforms.Compose(\n",
" [\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" ]\n",
")\n",
"\n",
"# Image preprocessing\n",
"processed_img = preprocess(img)\n",
"arr = np.asarray(processed_img).astype('float32') / 255\n",
"normalized = (arr - mean) / std\n",
"input_arr = np.expand_dims(normalized, 0)\n"
]
},
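{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail worth checking: the generated TFLite model takes channels-last (NHWC) inputs, whereas the original PyTorch model takes channels-first (NCHW) tensors. Since PIL images are laid out as height x width x channels, the array we just built is already in the right layout."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The TFLite model takes channels-last (NHWC) inputs, while the\n",
"# PyTorch model takes channels-first (NCHW) tensors. PIL arrays are\n",
"# already H x W x C, so expand_dims yields NHWC directly.\n",
"print('input shape:', input_arr.shape)  # expected: (1, 224, 224, 3)\n",
"print('input dtype:', input_arr.dtype)  # expected: float32"
]
},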
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run the generate TFLite model with the example input."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TFLite out: 258\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"interpreter = tf.lite.Interpreter(model_path=model_path)\n",
"interpreter.allocate_tensors()\n",
"\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"\n",
"interpreter.set_tensor(input_details[0]['index'], input_arr)\n",
"interpreter.invoke()\n",
"output_arr = interpreter.get_tensor(output_details[0]['index'])\n",
"\n",
"print('TFLite out:', np.argmax(output_arr))"
]
},
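{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the result easier to interpret, we can map the predicted index to its ImageNet label. The sketch below assumes the standard class list published in the `pytorch/hub` repository; any ImageNet-1k label file will do."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Map the predicted index to a human-readable ImageNet label\n",
"# (assumes the class list published in the pytorch/hub repository)\n",
"labels_path = os.path.join(cwd, 'imagenet_classes.txt')\n",
"download_url_to_file(\n",
"    'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt',\n",
"    labels_path,\n",
")\n",
"\n",
"with open(labels_path) as f:\n",
"    labels = [line.strip() for line in f]\n",
"\n",
"print('TFLite prediction:', labels[np.argmax(output_arr)])"
]
},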
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check whether the output is consistent with the one predicted by the original model."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch out: tensor(258)\n"
]
}
],
"source": [
"torch_input = torch.from_numpy(input_arr).permute((0, 3, 1, 2))\n",
"\n",
"with torch.no_grad():\n",
" torch_output = model(torch_input)\n",
"\n",
"print('PyTorch out:', torch.argmax(torch_output))"
]
}
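,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond the top-1 index, we can also compare the raw logits element-wise. As a rough check, a small maximum absolute difference suggests a faithful float32 conversion; the `1e-3` tolerance below is just a common default, not a figure guaranteed by the converter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the raw logits element-wise, not just the argmax\n",
"# The 1e-3 tolerance is a common default for float32 conversions,\n",
"# not a figure guaranteed by the converter\n",
"diff = np.abs(torch_output.numpy() - output_arr)\n",
"print('max abs difference:', diff.max())\n",
"print('allclose (atol=1e-3):', np.allclose(torch_output.numpy(), output_arr, atol=1e-3))"
]
}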
],
"metadata": {
"interpreter": {
"hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608"
},
"kernelspec": {
"display_name": "Python 3.8.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}