# Quantization

You may already know that quantization leads to smaller models and faster inference. But do you know why? Here we cover the basics of quantization.

## Basic concept
A floating-point value `f` can be represented by an integer `q = round(f / s + o)` (the quantized value), clamped to the integer range, given the quantization parameters `s` (scale) and `o` (offset, also called zero point).
Converting `f` to `q` is called "quantize". Conversely, recovering `f ≈ (q - o) * s` from `q` is called "dequantize". Because of the rounding, the recovered value is only an approximation of the original.

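Usually, we use 8-bit quantization, which means that `q` lies within `[0, 255]` (unsigned) or `[-128, 127]` (signed). Each value then takes 1 byte instead of the 4 bytes of a 32-bit float, which is where the 4x reduction in model size comes from.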
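As a minimal sketch of the two definitions above for unsigned 8-bit values, here are plain-Python `quantize`/`dequantize` helpers. The helper names and the example scale and offset are illustrative only, not any library's API.

```python
def quantize(f, s, o, q_min=0, q_max=255):
    """Map a float to the integer grid defined by scale s and offset o."""
    q = round(f / s + o)               # Python's round() rounds half to even
    return max(q_min, min(q_max, q))   # clamp into the 8-bit range


def dequantize(q, s, o):
    """Recover an approximation of the original float."""
    return (q - o) * s


s, o = 0.05, 128               # example parameters, made up for illustration
q = quantize(1.234, s, o)      # 1.234 / 0.05 + 128 = 152.68 -> 153
f_hat = dequantize(q, s, o)    # (153 - 128) * 0.05, approximately 1.25
```

The round-trip error stays within `s / 2` as long as `f` lies inside the representable range; values outside it are clamped to `q_min` or `q_max`.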

## Dynamic quantization & static quantization
Generally speaking, there are two kinds of quantization. Dynamic quantization (aka hybrid quantization or dynamic range quantization) calculates the quantization parameters of the activations on the fly, during inference. All you need to do beforehand is to convert the weights to quantized values. You may refer to [this tutorial](../hybrid.ipynb) to perform this kind of quantization using TinyNeuralNetwork.
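To illustrate the idea, here is a toy pure-Python dot product in the style of dynamic quantization: the weights are quantized once ahead of time (int8, symmetric), while the activation scale and offset are computed from each incoming input (uint8, asymmetric). The function names are hypothetical, and real kernels work on tensors with fixed-point arithmetic; this only shows where each parameter comes from.

```python
def quantize_weights(w):
    """Done once, offline: symmetric int8 quantization with offset 0."""
    s_w = max(abs(v) for v in w) / 127.5           # (q_max - q_min) / 2 for int8
    w_q = [max(-128, min(127, round(v / s_w))) for v in w]
    return w_q, s_w


def dynamic_dot(x, w_q, s_w):
    """Done at inference time, for every input."""
    # 1. Activation parameters come from the range of *this* input.
    f_min, f_max = min(min(x), 0.0), max(max(x), 0.0)
    s_x = (f_max - f_min) / 255 or 1.0             # guard against an all-zero input
    o_x = -round(f_min / s_x)
    x_q = [max(0, min(255, round(v / s_x + o_x))) for v in x]
    # 2. Integer multiply-accumulate.
    acc = sum((q - o_x) * wq for q, wq in zip(x_q, w_q))
    # 3. Dequantize the accumulator back to float.
    return acc * s_x * s_w


w_q, s_w = quantize_weights([0.5, -0.25, 1.0])
small = dynamic_dot([1.0, 2.0, -1.0], w_q, s_w)     # exact result: -1.0
large = dynamic_dot([10.0, 20.0, -10.0], w_q, s_w)  # exact result: -10.0
```

A single activation scale fixed from the first input would clamp most of the second one; recomputing it per input avoids that at the cost of a little runtime overhead.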

Static quantization, on the other hand, has the quantization parameters calculated before the inference phase. There are generally two ways to achieve that: [quantization-aware training](../qat.ipynb) and [post-training quantization](../post.ipynb). We will illustrate the process with an example in the following sections.

If you want to choose a type of quantization without going into the details, you may refer to the decision tree in the figure below or to the summary table.

![](https://www.tensorflow.org/lite/performance/images/quantization_decision_tree.png)

| Technique                  | Benefits (vs. float32 model) | Hardware                        |
|----------------------------|------------------------------|---------------------------------|
| Dynamic quantization | 4x smaller, 2x-3x speedup    | CPU                             |
| Static quantization  | 4x smaller, 3x+ speedup      | CPU, Edge TPU, Microcontrollers |

## How is static quantization performed in DNN frameworks?
The key here is fake quantization. What is fake quantization? Suppose the original floating-point computation graph contains only one operation, `y = conv(x)`, and we want the quantized graph to compute `y' = q_conv(x')`.

With fake quantization, we insert `x' = fake_quantize(x)` and `y' = fake_quantize(y)` around the floating-point operation.
First, we observe the minimum and maximum values of `x` and `y`; let's call them `x_min`, `x_max` and `y_min`, `y_max`. Then we calculate the quantization parameters, scale `s` and offset `o`, from these ranges. Below, `f_min`/`f_max` stand for the observed range of a tensor and `q_min`/`q_max` for the integer range.

Asymmetric quantization:
```py
f_min, f_max = min(f_min, 0), max(f_max, 0)  # the range must contain 0
s = (f_max - f_min) / (q_max - q_min)
o = q_min - round(f_min / s)
```

Symmetric quantization:
```py
s = max(f_max, -f_min) / ((q_max - q_min) / 2)
o = 128  # uint8, range [0, 255]
o = 0    # int8, range [-128, 127]
```
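The two formulas can be written as small Python functions (hypothetical helper names, not a library API). Plugging in ranges that appear later in this tutorial reproduces the scales PyTorch reports: the weight observer's range `[-0.5689650177955627, 0.4857633411884308]` with symmetric int8 quantization gives about `0.0045`, and the activation range from `0.0` to `4.370081901` (the value as printed, which is truncated) with asymmetric uint8 quantization gives about `0.0171`.

```python
def asymmetric_qparams(f_min, f_max, q_min=0, q_max=255):
    f_min, f_max = min(f_min, 0.0), max(f_max, 0.0)   # the range must contain 0
    s = (f_max - f_min) / (q_max - q_min)
    o = q_min - round(f_min / s)
    return s, o


def symmetric_qparams(f_min, f_max, signed=True):
    q_min, q_max = (-128, 127) if signed else (0, 255)
    s = max(f_max, -f_min) / ((q_max - q_min) / 2)
    o = 0 if signed else 128
    return s, o


# Weight range reported by the weight observer in the QAT example
w_s, w_o = symmetric_qparams(-0.5689650177955627, 0.4857633411884308)
# Activation range reported for the fused ConvBnReLU2d output
a_s, a_o = asymmetric_qparams(0.0, 4.370081901)
print(round(w_s, 4), w_o)  # 0.0045 0
print(round(a_s, 4), a_o)  # 0.0171 0
```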

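Then, fake quantization is performed with these quantization parameters: `x' = fake_quantize(x) = (clamp(round(x / s + o), q_min, q_max) - o) * s`. `y'` is obtained in the same way. Note that `x'` is still a floating-point value; it just carries the rounding and clamping error that the quantized kernel will introduce.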
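A plain-Python version of the `fake_quantize` formula (a sketch over lists, not PyTorch's `FakeQuantize` module) makes it clear that the output stays floating point, only snapped to the grid `(q - o) * s` and clipped to the representable range.

```python
def fake_quantize(x, s, o, q_min=0, q_max=255):
    """Quantize and immediately dequantize every value of x."""
    out = []
    for v in x:
        q = max(q_min, min(q_max, round(v / s + o)))  # clamp(round(x / s + o), q_min, q_max)
        out.append((q - o) * s)                       # back to float
    return out


y = fake_quantize([0.1234, -1.0, 3.0], s=0.01, o=0)
print(y)
```

Here `0.1234` is snapped to roughly `0.12`, `-1.0` is clipped to `0.0`, and `3.0` is clipped to roughly `2.55`, the largest value representable with `s = 0.01`, `o = 0`.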

Finally, we replace the floating kernels with the quantized kernels. So the computation graph will contain the following operations.
```py
x_q = quantize(x, s_x, o_x)      # float -> integer
y_q = q_conv(x_q)                # integer kernel; its output uses (s_y, o_y)
y = dequantize(y_q, s_y, o_y)    # integer -> float
```
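To see roughly what an integer kernel such as `q_conv` does internally, here is a toy quantized dot product (a 1x1 convolution is a dot product over input channels). It is a sketch under simplifying assumptions: no bias, per-tensor parameters, and a floating-point requantization multiplier where real kernels use a fixed-point multiplier and shift. All names are hypothetical.

```python
def quantize(x, s, o, q_min, q_max):
    return [max(q_min, min(q_max, round(v / s + o))) for v in x]


def q_dot(x_q, s_x, o_x, w_q, s_w, s_y, o_y):
    """Integer dot product, requantized to the output parameters (s_y, o_y)."""
    acc = sum((xq - o_x) * wq for xq, wq in zip(x_q, w_q))  # integer accumulator
    m = s_x * s_w / s_y                                      # requantization multiplier
    return max(0, min(255, round(acc * m) + o_y))


def dequantize(q, s, o):
    return (q - o) * s


x, w = [1.0, 2.0, -1.0], [0.5, -0.25, 1.0]   # float reference: 0.5 - 0.5 - 1.0 = -1.0
s_x, o_x = 3 / 255, 85                        # static uint8 params for x in [-1, 2]
s_w = 1.0 / 127.5                             # symmetric int8 weights, offset 0
s_y, o_y = 0.02, 100                          # static params for y (example values)

x_q = quantize(x, s_x, o_x, 0, 255)
w_q = quantize(w, s_w, 0, -128, 127)
y_q = q_dot(x_q, s_x, o_x, w_q, s_w, s_y, o_y)
y = dequantize(y_q, s_y, o_y)                 # close to -1.0
```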

## Static quantization in PyTorch
We use the following PyTorch model as an example.

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        y = self.relu(y)
        return x + y

model = Model()
```

The first step is to decide which part of the model should run with quantized kernels. Since all the operations in this model support quantization, we can simply quantize the input and dequantize the output, which gives the modified model below.

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.nn.quantized

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.fake_quant = torch.quantization.QuantStub()  # Quantize
        self.fake_dequant = torch.quantization.DeQuantStub()  # Dequantize

    def forward(self, x):
        x = self.fake_quant(x)
        y = self.conv(x)
        y = self.bn(y)
        y = self.relu(y)

        z = x + y
        z = self.fake_dequant(z)
        return z
```

The second step is to find all the requantizable functions in the model. Wait, what does `requantizable` mean? It refers to operations whose outputs need their own set of quantization parameters, different from those of their inputs. Typically, the list includes `add`, `mul`, `add_relu` and `cat`. Because they are plain function calls (e.g. `x + y`) rather than modules, PyTorch cannot attach observers to them, so we need to replace them with the corresponding methods of `torch.nn.quantized.FloatFunctional`.
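Why does `add` need its own quantization parameters? The sum of two tensors generally has a wider range than either input, so it has to be requantized to a separate scale and offset, which must come from observing the output. Below is a toy element-wise version (a hypothetical function that uses a floating-point intermediate instead of a real kernel's fixed-point arithmetic):

```python
def quantized_add(a_q, s_a, o_a, b_q, s_b, o_b, s_out, o_out, q_min=0, q_max=255):
    """Add two uint8-quantized lists and requantize to (s_out, o_out)."""
    out = []
    for qa, qb in zip(a_q, b_q):
        real = (qa - o_a) * s_a + (qb - o_b) * s_b   # dequantize and add
        out.append(max(q_min, min(q_max, round(real / s_out + o_out))))
    return out


# a covers [0, 1] and b covers [0, 4], so the sum covers [0, 5]
s_a, s_b, s_sum = 1 / 255, 4 / 255, 5 / 255
a_q, b_q = [0, 51, 255], [0, 51, 255]     # a = 0.0, 0.2, 1.0; b = 0.0, 0.8, 4.0
good = quantized_add(a_q, s_a, 0, b_q, s_b, 0, s_sum, 0)  # observed output scale
bad = quantized_add(a_q, s_a, 0, b_q, s_b, 0, s_a, 0)     # reusing an input scale
```

With the input scale reused, the largest sum `5.0` is clipped to `1.0`; this is why `FloatFunctional` carries an `activation_post_process` for the output of the operation.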

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.nn.quantized

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.fake_quant = torch.quantization.QuantStub()  # Quantize
        self.fake_dequant = torch.quantization.DeQuantStub()  # Dequantize
        self.float_functional = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        x = self.fake_quant(x)
        y = self.conv(x)
        y = self.bn(y)
        y = self.relu(y)

        z = self.float_functional.add(x, y)
        z = self.fake_dequant(z)
        return z
```

The model above is now ready for quantization. Next, we need to find the fusable nodes: sequences of nodes that can be treated as a single module during quantization, e.g. Conv2d-BatchNorm2d-ReLU. After fusion, no quantization step is inserted between these nodes.
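One reason Conv2d-BatchNorm2d is worth fusing: once its statistics are fixed, BatchNorm is just a per-channel affine transform, so it can be folded into the convolution's weights and bias and the quantized model runs a single convolution (the converted model later prints only a `QuantizedConvReLU2d`). A pure-Python sketch of the standard folding formula for a 1x1 convolution, with hypothetical helper names:

```python
import math


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm into a 1x1 conv. w[c] holds the weights of output channel c."""
    w_fold, b_fold = [], []
    for w_c, b_c, g, bt, m, v in zip(w, b, gamma, beta, mean, var):
        k = g / math.sqrt(v + eps)                 # per-channel BN scale
        w_fold.append([wi * k for wi in w_c])
        b_fold.append((b_c - m) * k + bt)
    return w_fold, b_fold


def conv1x1(x, w, b):
    """x: one value per input channel -> one value per output channel."""
    return [sum(wi * xi for wi, xi in zip(w_c, x)) + b_c for w_c, b_c in zip(w, b)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    return [(yc - m) / math.sqrt(v + eps) * g + bt
            for yc, g, bt, m, v in zip(y, gamma, beta, mean, var)]


w, b = [[2.0, -1.0], [0.5, 0.5]], [0.5, -0.1]
gamma, beta, mean, var = [1.5, 0.8], [0.1, -0.2], [0.2, 0.0], [4.0, 0.25]
x = [0.3, -0.7]

w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
separate = batch_norm(conv1x1(x, w, b), gamma, beta, mean, var)
fused = conv1x1(x, w_f, b_f)   # matches `separate` up to floating-point error
```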

```py
# distutils is deprecated (removed in Python 3.12); `packaging` is used instead
from packaging.version import Version

m = Model()
m.train()

m.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')

# fuse_modules_qat is available from PyTorch 1.11 onwards
if Version(torch.__version__) >= Version('1.11.0'):
    torch.ao.quantization.fuse_modules_qat(m, [['conv', 'bn', 'relu']], inplace=True)
else:
    torch.quantization.fuse_modules(m, [['conv', 'bn', 'relu']], inplace=True)

print(m)
```

```
Model(
  (conv): ConvBnReLU2d(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (bn): Identity()
  (relu): Identity()
  (fake_quant): QuantStub()
  (fake_dequant): DeQuantStub()
  (float_functional): FloatFunctional(
    (activation_post_process): Identity()
  )
)
```


The final step before training or calibration is quantization preparation. `prepare_qat` replaces the fused module with its QAT counterpart and adds `FakeQuantize` nodes to the weights and the outputs of the quantizable modules, including the requantizable ones.
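Each `FakeQuantize` node combines an observer with the fake-quantization formula. The observer shown in the output below is `MovingAverageMinMaxObserver`; our understanding is that it takes the range of the first batch as-is and afterwards updates the range as an exponential moving average (PyTorch's default averaging constant is 0.01). A toy per-tensor uint8 version under that assumption, with a hypothetical class name:

```python
class ToyFakeQuantize:
    """Toy per-tensor uint8 fake-quant node with a moving-average min/max observer."""

    def __init__(self, averaging_constant=0.01, q_min=0, q_max=255):
        self.c = averaging_constant
        self.q_min, self.q_max = q_min, q_max
        self.min_val, self.max_val = float("inf"), float("-inf")

    def observe(self, x):
        cur_min, cur_max = min(x), max(x)
        if self.min_val == float("inf"):
            # First batch: take the observed range as-is.
            self.min_val, self.max_val = cur_min, cur_max
        else:
            # Later batches: exponential moving average of the range.
            self.min_val += self.c * (cur_min - self.min_val)
            self.max_val += self.c * (cur_max - self.max_val)

    def qparams(self):
        # Asymmetric parameters; the range is widened to contain 0.
        f_min, f_max = min(self.min_val, 0.0), max(self.max_val, 0.0)
        s = (f_max - f_min) / (self.q_max - self.q_min) or 1.0
        o = self.q_min - round(f_min / s)
        return s, o

    def __call__(self, x):
        self.observe(x)
        s, o = self.qparams()
        return [(max(self.q_min, min(self.q_max, round(v / s + o))) - o) * s for v in x]


fq = ToyFakeQuantize()
fq([0.0, 1.0, 2.0])   # first batch: range becomes [0.0, 2.0]
fq([0.0, 3.0])        # max moves only 1% of the way towards 3.0
```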

```py
torch.quantization.prepare_qat(m, inplace=True)
print(m)
```

```
Model(
  (conv): ConvBnReLU2d(
    3, 3, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (bn): Identity()
  (relu): Iden
```

Now we can start the training process. For brevity, we only feed the model some randomly generated data here, without a loss or an optimizer; the forward passes are enough to update the observers.

```py
for _ in range(10):
    dummy_input = torch.randn(1, 3, 224, 224)
    m(dummy_input)

print(m)
```

```
Model(
  (conv): ConvBnReLU2d(
    3, 3, kernel_size=(1, 1), stride=(1, 1)
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0045]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.5689650177955627, max_val=0.4857633411884308)
    )
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0171]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=4.370081901
```

As the output above shows, training runs on the floating-point model with the `FakeQuantize` nodes inserted. To obtain an actual quantized model, we need to perform an explicit conversion.

```py
quantized_m = torch.quantization.convert(m)

print(quantized_m)
```

```
Model(
  (conv): QuantizedConvReLU2d(3, 3, kernel_size=(1, 1), stride=(1, 1), scale=0.017137575894594193, zero_point=0)
  (bn): Identity()
  (relu): Identity()
  (fake_quant): Quantize(scale=tensor([0.0340]), zero_point=tensor([122]), dtype=torch.quint8)
  (fake_dequant): DeQuantize()
  (float_functional): QFunctional(
    scale=0.045997653156518936, zero_point=87
    (activation_post_process): Identity()
  )
)
```


## Our quantization tool
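As you can see, applying quantization to a PyTorch model by hand takes quite a few steps. That's why we developed the quantization tools in TinyNeuralNetwork, which reduce the task to a few extra lines of code. Note that the tool starts from the original `model`, without any stubs or `FloatFunctional`, and inserts them automatically, as the output below shows.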

```py
import sys
sys.path.append('../..')

from tinynn.graph.quantization.quantizer import QATQuantizer

quantizer = QATQuantizer(model, dummy_input, work_dir='out')
q_model = quantizer.quantize()

for _ in range(10):
    dummy_input = torch.randn(1, 3, 224, 224)
    q_model(dummy_input)

q_model = torch.quantization.convert(q_model)

print(q_model)
```

```
Model_qat(
  (fake_quant_0): Quantize(scale=tensor([0.0350]), zero_point=tensor([127]), dtype=torch.quint8)
  (conv): QuantizedConvReLU2d(3, 3, kernel_size=(1, 1), stride=(1, 1), scale=0.017351238057017326, zero_point=0)
  (bn): Identity()
  (relu): Identity()
  (fake_dequant_0): DeQuantize()
  (float_functional_simple_0): QFunctional(
    scale=0.03781222179532051, zero_point=94
    (activation_post_process): Identity()
  )
)
```
