This is a brief introduction to using our model converter, with the pretrained MobileNetV2 model as an example.

In [17]:
import torch
import torchvision

# torchvision >= 0.13 deprecates the `pretrained=True` flag in favour of the
# `weights` enum; IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads.
# Older torchvision releases lack the enum, so fall back to the legacy flag there.
if hasattr(torchvision.models, 'MobileNet_V2_Weights'):
    model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1)
else:
    model = torchvision.models.mobilenet_v2(pretrained=True)

Next, we convert it to TFLite using TinyNeuralNetwork.

In [18]:
import sys

# Make the in-repo `tinynn` package importable. '../..' is resolved against the
# current working directory — presumably the repository root when the notebook
# is launched from its own folder; adjust if you run it from elsewhere.
sys.path.append('../..')

from tinynn.converter import TFLiteConverter

# Destination file for the converted model.
model_path = 'mobilenet_v2.tflite'

# Example input used for conversion: one 3x224x224 image in NCHW layout.
dummy_input = torch.randn(1, 3, 224, 224)

# The model is moved to the CPU and switched to evaluation mode before it is
# handed to the converter.
with torch.no_grad():
    model.cpu().eval()

    converter = TFLiteConverter(model, dummy_input, model_path)
    converter.convert()

INFO (tinynn.converter.base) Generated model saved to mobilenet_v2.tflite


Let's prepare an example input from an image downloaded from the web.

In [25]:
import os
from torch.hub import download_url_to_file
from PIL import Image
from torchvision import transforms
import numpy as np

cwd = os.path.abspath(os.getcwd())
img_path = os.path.join(cwd, 'dog.jpg')

# Candidate sources for the same sample image, tried in order. The second link
# is a mirror for readers who have difficulties accessing GitHub directly.
img_urls = [
    'https://github.com/pytorch/hub/raw/master/images/dog.jpg',
    'https://raw.fastgit.org/pytorch/hub/master/images/dog.jpg',
]

# Reuse a previously downloaded copy so re-running the cell needs no network;
# otherwise try each URL in turn until one succeeds.
if not os.path.exists(img_path):
    download_errors = []
    for url in img_urls:
        try:
            download_url_to_file(url, img_path)
            break
        except OSError as e:  # urllib's URLError/HTTPError are OSError subclasses
            download_errors.append(f'{url}: {e}')
    else:
        raise RuntimeError('Could not download the example image:\n' + '\n'.join(download_errors))

# Force 3-channel RGB (guards against grayscale/RGBA files) and release the
# file handle that Image.open keeps open until the data is loaded.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')

# ImageNet per-channel mean/std, the normalization used by torchvision's
# pretrained classification models.
mean = np.array([0.485, 0.456, 0.406], dtype='float32')
std = np.array([0.229, 0.224, 0.225], dtype='float32')

# Resize the shorter side to 256 px, then center-crop to 224x224.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)

# Image preprocessing: HWC uint8 -> float32 in [0, 1] -> normalized per channel,
# then add a batch axis, giving an NHWC array of shape (1, 224, 224, 3).
processed_img = preprocess(img)
arr = np.asarray(processed_img).astype('float32') / 255
normalized = (arr - mean) / std
input_arr = np.expand_dims(normalized, 0)


Let's run the generated TFLite model on the example input.

In [29]:
import tensorflow as tf

# Load the converted model and allocate its tensors before feeding data.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Only the first input and first output tensor are used here.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']
output_index = output_details[0]['index']

# Feed the NHWC example image, run inference and fetch the scores.
interpreter.set_tensor(input_index, input_arr)
interpreter.invoke()
output_arr = interpreter.get_tensor(output_index)

print('TFLite out:', np.argmax(output_arr))

TFLite out: 258


Let's check whether this prediction is consistent with the one made by the original PyTorch model.

In [28]:
# `input_arr` is NHWC (as fed to TFLite); the PyTorch model takes NCHW, so
# permute the axes to run the original model on exactly the same data.
torch_input = torch.from_numpy(input_arr).permute((0, 3, 1, 2))

with torch.no_grad():
    torch_output = model(torch_input)

print('PyTorch out:', torch.argmax(torch_output))

# Make the consistency check explicit so "Run All" fails loudly if the converted
# model disagrees with the original, instead of relying on a visual comparison.
tflite_label = int(np.argmax(output_arr))
torch_label = int(torch.argmax(torch_output))
assert torch_label == tflite_label, f'Prediction mismatch: PyTorch {torch_label} vs TFLite {tflite_label}'

PyTorch out: tensor(258)
