# ONNX Runtime

This notebook shows how to use ONNX Runtime to accelerate inference of a model trained in PyTorch. We train a simple CNN on the MNIST dataset, export it to ONNX format, and check that ONNX Runtime reproduces the PyTorch outputs. Finally, we quantize the model to int8 precision with ONNX Runtime's quantization tools, which shrinks the model's memory footprint and can further improve performance.



## Set Up ONNX Runtime

First, install torch, torchvision, onnx, and onnxruntime. Then import the necessary modules.

In [None]:
%pip install torch torchvision
%pip install onnx onnxruntime

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.quantization
import pathlib
import numpy as np
import torch.onnx
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, quantize_static, CalibrationDataReader, QuantType

## Train Model

We will train a simple CNN model on the MNIST dataset. 

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(12 * 13 * 13, 10)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)  
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  
        x = self.fc(x)
        output = F.log_softmax(x, dim=1)
        return output


train_loader = torch.utils.data.DataLoader(train_dataset, 32)
test_loader = torch.utils.data.DataLoader(test_dataset, 32)

device = "cpu"

epochs = 1

model = Net().to(device)
optimizer = optim.Adam(model.parameters())

model.train()

for epoch in range(1, epochs+1):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))

MODEL_DIR = pathlib.Path("./onnx_models")
MODEL_DIR.mkdir(exist_ok=True)
torch.save(model.state_dict(), MODEL_DIR / "original_model.p")
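
The input size of `self.fc`, `12 * 13 * 13`, follows from the usual output-size arithmetic for convolution and pooling layers. The sketch below reproduces that calculation in plain Python; the helper name `out_size` is hypothetical and is not part of PyTorch.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1

side = 28                                   # MNIST images are 28x28
side = out_size(side, kernel=3)             # conv1: 3x3 kernel, stride 1, no padding
side = out_size(side, kernel=2, stride=2)   # pool: 2x2 window, stride 2
channels = 12                               # out_channels of conv1
flat_features = channels * side * side      # number of inputs expected by self.fc

print(side, flat_features)
```

If the kernel size or the number of channels in `Net` changes, the first argument of `nn.Linear` has to change accordingly.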

## Export to ONNX

After training, export the model to ONNX format.


In [None]:
x, _ = next(iter(train_loader))
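torch.onnx.export(model,                                 # model to export
                  x,                                     # example input used to trace the model
                  str(MODEL_DIR / "mnist_model.onnx"),   # output file
                  export_params=True,                    # store the trained weights in the file
                  opset_version=10,                      # ONNX operator set version
                  do_constant_folding=True,              # pre-compute constant subexpressions
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},    # allow a variable batch size
                                'output': {0: 'batch_size'}})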

## Run Inference and Test Similarity

Next, validate the converted model by running inference and comparing the results with the PyTorch model.

In [None]:
torch_out = model(x)

onnx_model = onnx.load(MODEL_DIR / "mnist_model.onnx")
onnx.checker.check_model(onnx_model)

# Pass the path as a string: some onnxruntime versions accept only str or bytes here
ort_session = onnxruntime.InferenceSession(str(MODEL_DIR / "mnist_model.onnx"), providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
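
The comparison above passes when every element satisfies `|actual - desired| <= atol + rtol * |desired|`, which is the criterion documented for `np.testing.assert_allclose`. The following sketch spells out that rule in plain Python. The helper `within_tolerance` is hypothetical, and the numbers are made up for illustration rather than taken from the model.

```python
def within_tolerance(actual, desired, rtol=1e-3, atol=1e-5):
    """Element-wise check |actual - desired| <= atol + rtol * |desired|."""
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))

reference = [-2.3026, -0.1054, -4.6052]   # made-up log-probabilities
drifted = [-2.3030, -0.1054, -4.6050]     # same values with small numeric drift
wrong = [-2.4000, -0.1054, -4.6052]       # first entry is clearly off

print(within_tolerance(drifted, reference))
print(within_tolerance(wrong, reference))
```

The relative term dominates for large values and the absolute term for values near zero, so small floating-point differences between the two runtimes do not fail the check.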

## Quantization

### Dynamic Quantization

Dynamic quantization computes the quantization parameters (scale and zero point) for activations on the fly at inference time. Because the parameters adapt to each input, accuracy is usually higher than with static quantization, but the extra calculations increase the cost of inference.

In [None]:
!python -m onnxruntime.quantization.preprocess --input {MODEL_DIR / "mnist_model.onnx"} --output {MODEL_DIR / "mnist_model_processed.onnx"}

In [None]:
model_fp32 = MODEL_DIR / "mnist_model_processed.onnx"
model_quant = MODEL_DIR / "mnist_model_quant.onnx"
# quantize_dynamic writes the quantized model to model_quant and returns None
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
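
To make "scale and zero point" concrete, here is a minimal sketch of affine uint8 quantization in plain Python. It illustrates the general scheme and is not ONNX Runtime's implementation; the function names and the activation values are assumptions made for the example. With dynamic quantization, parameters like these are derived from each input's activations at run time.

```python
def quant_params(x_min, x_max, qmin=0, qmax=255):
    """Scale and zero point mapping [x_min, x_max] onto the uint8 range."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # avoid division by zero for an all-zero tensor
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def quantize(values, scale, zero_point, qmin=0, qmax=255):
    """Map floats to integers in [qmin, qmax], clipping out-of-range values."""
    return [min(qmax, max(qmin, int(round(v / scale)) + zero_point)) for v in values]

def dequantize(q_values, scale, zero_point):
    """Map quantized integers back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]

activations = [-1.0, -0.25, 0.0, 0.5, 3.0]   # made-up values
scale, zero_point = quant_params(min(activations), max(activations))
q = quantize(activations, scale, zero_point)
restored = dequantize(q, scale, zero_point)

print(scale, zero_point)
print(q)
print(restored)
```

Each restored value differs from the original by at most about one quantization step (`scale`), which is the rounding error the accuracy comparison later in the notebook measures indirectly.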

### Compare Size

Let's compare the size of the original model and the quantized model.

In [None]:
%ls -lh {MODEL_DIR}
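
`%ls` depends on the shell that IPython runs in, so it may not work on every platform. As a portable alternative, the sketch below lists file sizes with `pathlib`. The helper name `model_sizes` is hypothetical, and the directory `./onnx_models` matches `MODEL_DIR` from the training cell.

```python
import pathlib

def model_sizes(model_dir, pattern="*.onnx"):
    """Map each matching file name in model_dir to its size in KiB."""
    paths = sorted(pathlib.Path(model_dir).glob(pattern))
    return {p.name: p.stat().st_size / 1024 for p in paths}

# Prints nothing if the directory does not contain any .onnx files yet.
for name, kib in model_sizes("./onnx_models").items():
    print(f"{name:40s} {kib:10.1f} KiB")
```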

### Compare Accuracy

Let's compare the accuracy of the converted ONNX model and the quantized model. The accuracy of the quantized model should be close to that of the original model.

In [None]:
def test_onnx(model_name, data_loader):
    """Return the accuracy (in percent) of an ONNX model on data_loader."""
    onnx_model = onnx.load(model_name)
    onnx.checker.check_model(onnx_model)
    # Name the execution provider explicitly and pass the path as a string
    ort_session = onnxruntime.InferenceSession(str(model_name), providers=["CPUExecutionProvider"])
    input_name = ort_session.get_inputs()[0].name
    correct = 0
    for data, target in data_loader:
        output = ort_session.run(None, {input_name: to_numpy(data)})[0]
        pred = torch.from_numpy(output).argmax(dim=1)  # index of the max log-probability
        correct += pred.eq(target).sum().item()

    return 100. * correct / len(data_loader.dataset)

acc = test_onnx(MODEL_DIR / "mnist_model.onnx", test_loader)
print(f"Accuracy of the original model is {acc}%")

qacc = test_onnx(MODEL_DIR / "mnist_model_quant.onnx", test_loader)
print(f"Accuracy of the quantized model is {qacc}%")


### Static Quantization

Static quantization computes the quantization parameters ahead of time from a calibration dataset. Inference is typically faster than with dynamic quantization because no parameters have to be computed at run time, but accuracy can be lower. The calibration data is supplied through a subclass of the `CalibrationDataReader` class.

In [None]:
class QuantDR(CalibrationDataReader):
    def __init__(self, torch_data_loader, input_name):
        self.torch_data_loader = torch_data_loader
        self.input_name = input_name
        self.datasize = len(torch_data_loader)
        self.enum_data = iter(torch_data_loader)

    def to_numpy(self, tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def get_next(self):
        # Return the next batch as an input dict, or None when the data is exhausted
        batch = next(self.enum_data, None)
        if batch is not None:
            return {self.input_name: self.to_numpy(batch[0])}
        else:
            return None

    def rewind(self):
        self.enum_data = iter(self.torch_data_loader)

calibration_data = QuantDR(train_loader, ort_session.get_inputs()[0].name)
model__static_quant = MODEL_DIR / "mnist_model_static_quant.onnx"
# quantize_static writes the quantized model to model__static_quant and returns None
quantize_static(model_fp32, model__static_quant, calibration_data, weight_type=QuantType.QInt8)
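
What calibration does can be sketched with simple min/max range tracking, one of the calibration strategies ONNX Runtime offers. The code below is an illustration in plain Python, not the library's implementation. The function name `calibrate_range` and all numbers are assumptions made for the example.

```python
def calibrate_range(batches):
    """Track the global min and max over all calibration batches."""
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    return lo, hi

# Made-up activation values observed while running three calibration batches
calibration_batches = [[0.0, 0.8, 2.1], [0.1, 3.9, 1.2], [0.0, 5.1, 0.4]]
lo, hi = calibrate_range(calibration_batches)

# The parameters are computed once and then stay fixed at inference time.
scale = (max(hi, 0.0) - min(lo, 0.0)) / 255
zero_point = round(-min(lo, 0.0) / scale)

# A value outside the calibrated range is clipped to the uint8 limits.
unseen = 6.0
clipped = min(255, max(0, round(unseen / scale) + zero_point))

print(lo, hi, scale, zero_point, clipped)
```

Because the range is fixed in advance, inputs that fall outside it are clipped, which is one reason a static model can lose accuracy when the calibration data is not representative.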

### Compare Size

Let's compare the size of the original model and the statically quantized model.

In [None]:
%ls -lh {MODEL_DIR}

### Compare Accuracy

Let's compare the accuracy of the converted ONNX model and the statically quantized model. The accuracy of the quantized model should be close to that of the original model.

In [None]:
static_qacc = test_onnx(model__static_quant, test_loader)
print(f"Accuracy of the static quantized model is {static_qacc}%")
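
The notebook compares size and accuracy but does not measure speed. A simple way to check the latency of each variant is to time repeated calls and take the median. The helper below is a hypothetical sketch built on the standard library; the workload passed to it here is a stand-in so that the example runs on its own.

```python
import statistics
import time

def median_latency_ms(fn, warmup=5, repeats=50):
    """Median wall-clock time of fn() in milliseconds, after warm-up calls."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)

# Stand-in workload; replace with a real inference call.
print(f"{median_latency_ms(lambda: sum(range(10_000))):.3f} ms")
```

In this notebook, the same helper could be called with `lambda: ort_session.run(None, ort_inputs)` and `lambda: model(x)` to compare ONNX Runtime with PyTorch, and with sessions created from the two quantized files to compare the quantization methods. Results depend on the hardware and on the batch size.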