# PyTorch Quantization

PyTorch supports INT8 quantization. Compared with a typical FP32 model, this allows up to a 4x reduction in model size and a 4x reduction in memory bandwidth requirements,
while still achieving comparable accuracy for many applications. This notebook demonstrates how to quantize a model from FP32 to INT8 using PyTorch's quantization tooling. We train a simple CNN on MNIST, quantize it, and compare the accuracy and size of the quantized model with the original FP32 model.

## Setup PyTorch

First, let's install PyTorch and torchvision and then import the required modules.


In [None]:
%pip install torch torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.quantization
import pathlib

## Dynamic Quantization

In dynamic quantization, the weights are quantized ahead of time, while activations are read and stored in floating point and quantized only on the fly for the compute step.
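
To make this data flow concrete, here is a minimal pure-Python sketch of the idea for a single dot product. It is a simplified illustration under stated assumptions, not PyTorch's actual kernel: it uses symmetric per-tensor scaling, and the helper names `quantize_symmetric` and `dynamic_linear` are hypothetical. The weights are quantized once ahead of time, while the activation scale is computed from each input at call time and the result is returned in floating point.

```python
def quantize_symmetric(values, num_bits=8):
    """Map floats to signed integers with a single scale (symmetric, per tensor)."""
    qmax = 2 ** (num_bits - 1) - 1                      # 127 for int8
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dynamic_linear(x, q_weight, w_scale, bias=0.0):
    """Dot product with pre-quantized weights and on-the-fly activation quantization."""
    q_x, x_scale = quantize_symmetric(x)                # activation scale is chosen per call
    acc = sum(a * w for a, w in zip(q_x, q_weight))     # integer multiply-accumulate
    return acc * x_scale * w_scale + bias               # dequantize: the output is a float


weights = [0.5, -1.25, 0.75, 2.0]
q_weight, w_scale = quantize_symmetric(weights)         # done once, ahead of time

x = [1.0, 0.2, -0.4, 0.05]
y_float = sum(a * w for a, w in zip(x, weights))
y_quant = dynamic_linear(x, q_weight, w_scale)
print(f"float: {y_float:.4f}  dynamic int8: {y_quant:.4f}")
```

The two printed values should agree closely; the small gap between them is the rounding error introduced by the integer representation.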

### Load MNIST dataset 

First, we load the MNIST dataset.

In [None]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True,transform=transform)
test_dataset = datasets.MNIST('./data', train=False,transform=transform)

### Train the Model

Next, we define a simple CNN model and train it on the MNIST dataset.

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(12 * 13 * 13, 10)

    def forward(self, x):
        x = x.reshape(-1, 1, 28, 28)  
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = x.reshape(x.size(0), -1)  
        x = self.fc(x)
        output = F.log_softmax(x, dim=1)
        return output


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)  # shuffle training data each epoch
test_loader = torch.utils.data.DataLoader(test_dataset, 32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epochs = 1

model = Net().to(device)
optimizer = optim.Adam(model.parameters())

model.train()

for epoch in range(1, epochs+1):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:  # log every 100 batches to keep the output readable
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
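
The `12 * 13 * 13` input size of the `fc` layer follows from the tensor shapes in `forward`. The sketch below reproduces that arithmetic with the standard output-size formula for convolution and pooling layers; the helper name `out_size` is ours and is not part of PyTorch.

```python
def out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = 28                                         # MNIST images are 28x28
side = out_size(side, kernel_size=3)              # conv1: 28 -> 26
side = out_size(side, kernel_size=2, stride=2)    # pool:  26 -> 13
fc_in_features = 12 * side * side                 # 12 output channels from conv1
print(side, fc_in_features)
```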

### Quantize Model

After training, we can quantize the model using the `torch.quantization.quantize_dynamic` function from PyTorch. Here only the `nn.Linear` layer is quantized.

In [None]:
model.to('cpu')
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

### Check Model Size

Saving both state dicts and listing the directory shows that the quantized model is much smaller than the original model.

In [None]:
models_dir = pathlib.Path("./models/")
models_dir.mkdir(exist_ok=True, parents=True)
torch.save(model.state_dict(), "./models/original_model.p")
torch.save(quantized_model.state_dict(), "./models/quantized_model.p")

%ls -lh models
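
As a rough cross-check on the listing, the expected sizes can be estimated from the parameter counts of `Net`. This back-of-the-envelope sketch assumes 4 bytes per FP32 value and 1 byte per INT8 weight, and that only the `fc` weights are quantized (the convolution and all biases stay in FP32, since only `nn.Linear` was passed to `quantize_dynamic`). It ignores serialization overhead and the stored scale and zero point, so the files on disk will be somewhat larger.

```python
# Parameter counts for the Net defined earlier
conv_params = 12 * 1 * 3 * 3 + 12       # conv1 weights + biases
fc_weights = 12 * 13 * 13 * 10          # fc weight matrix
fc_biases = 10

fp32_bytes = 4 * (conv_params + fc_weights + fc_biases)
# Assumption: only the fc weights shrink to 1 byte each; everything else stays FP32
int8_bytes = 4 * conv_params + 1 * fc_weights + 4 * fc_biases
ratio = fp32_bytes / int8_bytes

print(f"FP32 estimate: {fp32_bytes / 1024:.1f} KiB")
print(f"INT8 estimate: {int8_bytes / 1024:.1f} KiB")
print(f"ratio: {ratio:.2f}x")
```

Because the `fc` layer holds almost all of the parameters in this model, the estimate lands close to the ideal 4x reduction.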

### Check Accuracy

Evaluating both models on the test set shows that the quantized model has accuracy comparable to the original model.

In [None]:
def test(model, device, data_loader, quantized=False):
    model.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)

    return 100. * correct / len(data_loader.dataset)

original_acc = test(model, "cpu", test_loader)
quantized_acc = test(quantized_model, "cpu", test_loader)

# Two decimal places so that small accuracy differences remain visible
print('Original model accuracy: {:.2f}%'.format(original_acc))
print('Quantized model accuracy: {:.2f}%'.format(quantized_acc))

## Post-training Static Quantization

Post-training static quantization quantizes both weights and activations, and requires a calibration step after training. Here we quantize the model using the functions in the `torch.ao.quantization.quantize_fx` module and compare the accuracy and size of the quantized model with the original FP32 model.

To apply post-training static quantization, first define a model or load a pre-trained one, then create a quantization configuration mapping using the default for the QNNPACK engine. Set the model to evaluation mode and create a sample input tensor. Next, prepare the model for quantization with `quantize_fx.prepare_fx()`, which applies the configuration mapping and inserts observers so the model can later be converted to INT8. The prepared model is then run on the input tensor to calibrate those observers. Finally, obtain the quantized model by calling `quantize_fx.convert_fx()` and save it to disk.


In [None]:
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

loaded_model = Net()
loaded_model.load_state_dict(torch.load("./models/original_model.p"))
model_to_quantize = copy.deepcopy(loaded_model)

qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()

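# Take a single test image as the example and calibration input
input_fp32 = next(iter(test_loader))[0][0:1]
input_fp32 = input_fp32.to('cpu')  # .to() is not in-place, so keep the result

# example_inputs is expected to be a tuple of positional arguments
model_fp32_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (input_fp32,))
model_fp32_prepared(input_fp32)  # calibration pass: observers record activation ranges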
model_int8 = quantize_fx.convert_fx(model_fp32_prepared)

torch.save(model_int8.state_dict(), "./models/post_quantized_model.p")
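
During calibration, the observers inserted by `prepare_fx` record statistics of the activations flowing through the model, and `convert_fx` turns those statistics into quantization parameters. The following is a simplified pure-Python stand-in for a min/max observer, shown only to illustrate the idea: the class `MinMaxSketch` is hypothetical and is not PyTorch's observer implementation. It also suggests why calibrating on several representative batches, rather than a single sample, tends to give a better range estimate.

```python
class MinMaxSketch:
    """Tracks the running min/max of observed values and derives uint8 parameters."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self):
        # Extend the range to include 0.0 so that zero is exactly representable
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        span = hi - lo
        scale = span / (self.qmax - self.qmin) if span > 0 else 1.0
        zero_point = round(self.qmin - lo / scale)
        return scale, zero_point


observer = MinMaxSketch()
# Made-up activation values standing in for calibration batches
for batch in ([0.1, 0.4, -0.2], [1.5, -0.3, 0.0], [0.7, 2.0, -1.0]):
    observer.observe(batch)

scale, zero_point = observer.qparams()
print(f"scale={scale:.5f}, zero_point={zero_point}")
```

Each additional batch can only widen the observed range, so a single calibration sample risks underestimating it and clipping activations at inference time.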

### Check Model Size

Again, we can see that the quantized model is much smaller than the original model.

In [None]:
%ls -lh models

### Check Accuracy

Again, we can see that the accuracy of the quantized model is not much different from that of the original model.

In [None]:
quantized_acc = test(model_int8, "cpu", test_loader, quantized=True)
print('Post quantized model accuracy: {:.2f}%'.format(quantized_acc))