# tinynn/graph/quantization/fake_quantize.py
import torch
from torch.nn import Module
from torch.quantization.observer import _with_args


class FakeQuantizeBFloat16(Module):
    """Simulate the quantize and dequantize operations for bfloat16 at training time"""

    def __init__(self, **observer_kwargs):
        # observer_kwargs is accepted for qconfig compatibility and ignored
        super(FakeQuantizeBFloat16, self).__init__()
        # Keep the toggle in a buffer so it travels with state_dict and device moves
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))

    def enable_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = 1 if enabled else 0

    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    def forward(self, X):
        if self.fake_quant_enabled[0] == 1:
            if isinstance(X, (tuple, list)):
                # Handle container inputs recursively
                return [self.forward(x) for x in X]
            elif isinstance(X, torch.Tensor) and X.is_floating_point():
                # Round-trip through bfloat16 to simulate its precision loss
                dtype = X.dtype
                X = X.to(dtype=torch.bfloat16).to(dtype=dtype)
        return X

    # Support partial construction via with_args, like torch.quantization.FakeQuantize
    with_args = classmethod(_with_args)
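

# A minimal usage sketch (illustrative; the qconfig wiring below is an
# assumption, not part of this file): register FakeQuantizeBFloat16 for both
# activations and weights so prepare_qat inserts it throughout the model.
#
#   bf16_qconfig = torch.quantization.QConfig(
#       activation=FakeQuantizeBFloat16.with_args(),
#       weight=FakeQuantizeBFloat16.with_args(),
#   )
#   model.qconfig = bf16_qconfig
#   torch.quantization.prepare_qat(model, inplace=True)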


class FakeQuantizeTFLite(torch.quantization.FakeQuantize):
    """FakeQuantize variant that mimics TFLite rounding.

    Observation and fake quantization run as two separate passes so that the
    observer always sees the raw input. Before fake quantization, the input is
    nudged away from zero by a tiny fraction of one quantization step, so that
    values sitting exactly on a rounding boundary round away from zero
    (TFLite's behavior) instead of half-to-even (PyTorch's default).
    """

    def forward(self, X):
        observer_enabled = self.observer_enabled[0] == 1
        fake_quant_enabled = self.fake_quant_enabled[0] == 1

        if observer_enabled:
            # First pass: observe only, with fake quantization temporarily off
            if fake_quant_enabled:
                torch.quantization.disable_fake_quant(self)
            X = super().forward(X)
            if fake_quant_enabled:
                torch.quantization.enable_fake_quant(self)

        if fake_quant_enabled:
            # Second pass: fake-quantize only, with the observer temporarily off
            if observer_enabled:
                torch.quantization.disable_observer(self)
            X = X + self.scale * 1e-6 * torch.sign(X.detach())
            X = super().forward(X)
            if observer_enabled:
                torch.quantization.enable_observer(self)

        return X
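

# A quick illustration of why the nudge above matters (standalone demo, not
# part of the original file): PyTorch rounds ties to even, so 2.5 rounds to 2,
# while a tiny signed epsilon tips the tie away from zero, the way TFLite rounds.
#
#   x = torch.tensor([2.5, -2.5])
#   torch.round(x)                          # tensor([ 2., -2.])  half-to-even
#   torch.round(x + 1e-6 * torch.sign(x))   # tensor([ 3., -3.])  half-away-from-zero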


def disable_fake_quant(mod):
    """
    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(tinynn.graph.quantization.disable_fake_quant)
    """
    if isinstance(mod, FakeQuantizeBFloat16):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """
    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(tinynn.graph.quantization.enable_fake_quant)
    """
    if isinstance(mod, FakeQuantizeBFloat16):
        mod.enable_fake_quant()


class PTQFakeQuantize(torch.quantization.FakeQuantize):
    """Fake-quantize module for post-training quantization (PTQ), used to speed
    up quantization error analysis.

    When calibrating, enable the observer and disable fake quantization;
    when validating the model's quantization error, enable fake quantization
    and disable the observer.
    """
    def forward(self, X):
        if self.observer_enabled[0] == 1:
            # Calibration path: only collect statistics; detach so no autograd
            # history is recorded
            self.activation_post_process(X.detach())

        if self.fake_quant_enabled[0] == 1:
            # Refresh scale/zero_point from the observer if they still look like
            # initialization defaults, i.e. calculate_qparams has not run yet
            if (self.scale == 1 and self.zero_point == 0) or (
                float(self.scale * (self.quant_max - self.quant_min + 1)) == 256
            ):
                _scale, _zero_point = self.calculate_qparams()
                _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
                if self.scale.shape != _scale.shape:
                    self.scale.resize_(_scale.shape)
                    self.zero_point.resize_(_zero_point.shape)
                self.scale.copy_(_scale)
                self.zero_point.copy_(_zero_point)

            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max
                )
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point, self.quant_min, self.quant_max
                )
        return X
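

# A minimal PTQ workflow sketch following the docstring above (loop and
# variable names are assumptions): observe during calibration, then
# fake-quantize during evaluation to measure quantization error.
#
#   prepared.apply(torch.quantization.enable_observer)
#   prepared.apply(torch.quantization.disable_fake_quant)
#   for batch in calib_loader:
#       prepared(batch)
#
#   prepared.apply(torch.quantization.disable_observer)
#   prepared.apply(torch.quantization.enable_fake_quant)
#   for batch in val_loader:
#       evaluate(prepared(batch))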


def set_ptq_fake_quantize(name, module):
    # `name` and `module` match the qconfig override hook signature; the same
    # qconfig is returned for every layer
    # Weights: symmetric per-tensor qint8 driven by a MinMaxObserver
    weight_fq = PTQFakeQuantize.with_args(
        observer=torch.quantization.MinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
    )
    # Activations: asymmetric quint8 driven by a HistogramObserver
    asym_fq = PTQFakeQuantize.with_args(
        observer=torch.quantization.HistogramObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
        reduce_range=False,
    )
    qconfig_new = torch.quantization.QConfig(asym_fq, weight_fq)
    return qconfig_new
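

# A usage sketch for plugging this hook into TinyNeuralNetwork's quantizer
# (the class name, import path and config key are assumptions; check the
# TinyNeuralNetwork examples for the exact API):
#
#   from tinynn.graph.quantization.quantizer import PostQuantizer
#
#   quantizer = PostQuantizer(
#       model, dummy_input, work_dir='out',
#       config={'override_qconfig_func': set_ptq_fake_quantize},
#   )
#   ptq_model = quantizer.quantize()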