in tinynn/graph/quantization/quantizer.py [0:0]
def prepare_onnx_export(self, q_model):
    """Prepares for ONNX model export

    Args:
        q_model (nn.Module): The QAT/PTQ-prepared model
    """

    if self.backend == 'onnx':
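        # Note: `torch_q`, `nnq`, `nn`, `inspect` and `LooseVersion` are
        # module-level imports defined elsewhere in quantizer.py.
        #
        # Every fake-quant module is normalized below to symmetric int8
        # (quant range [-128, 127], zero_point 0), the layout an int8
        # QDQ-style ONNX graph expects:
        #   * symmetric schemes only need their zero_point reset to 0;
        #   * fixed-qparams observers pinned at zero_point 128 keep their scale
        #     and move to zero_point 0, while ones pinned at zero_point 0 double
        #     their scale so [0, 127 * 2 * scale] still covers the original
        #     [0, 255 * scale] span.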
        for mod in self.leaf_nodes:
            torch.quantization.disable_fake_quant(mod.activation_post_process)

        fq_base_cls = getattr(torch_q, 'FakeQuantizeBase', torch_q.FakeQuantize)
        for n, m in q_model.named_modules():
            if isinstance(m, fq_base_cls):
                if m.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
                    m.zero_point.fill_(0)
                else:
                    assert m.qscheme == torch.per_tensor_affine
                    assert isinstance(m, torch_q.FixedQParamsFakeQuantize)

                    m.qscheme = torch.per_tensor_symmetric
                    if torch.all(m.zero_point == 128):
                        m.zero_point.fill_(0)
                    else:
                        assert not torch.is_nonzero(m.zero_point)
                        m.scale *= 2

                m.quant_min = -128
                m.quant_max = 127
                m.dtype = torch.int8

                if hasattr(m, 'activation_post_process'):
                    m.activation_post_process.quant_min = -128
                    m.activation_post_process.quant_max = 127
                    m.activation_post_process.dtype = torch.int8
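
    # torch >= 1.12: QAT inserts FusedMovingAvgObsFakeQuantize modules; their
    # forward is rerouted through the plain FakeQuantize.forward so that tracing
    # only sees the regular fake_quantize ops (the fused observer op presumably
    # has no ONNX lowering).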
    if LooseVersion(torch.__version__) >= LooseVersion('1.12.0'):

        def get_wrapper(mod):
            def wrapper(*args, **kwargs):
                return torch_q.FakeQuantize.forward(mod, *args, **kwargs)

            return wrapper

        fused_fq_cls = getattr(torch_q, 'FusedMovingAvgObsFakeQuantize', None)
        if fused_fq_cls is not None:
            for m in q_model.modules():
                if isinstance(m, fused_fq_cls):
                    m.forward = get_wrapper(m)
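
    # torch 1.10 / 1.11: FakeQuantize modules are swapped for the minimal
    # _FakeQuantize wrapper below, which calls the per-tensor / per-channel
    # fake_quantize ops directly and drops the observer machinery that is not
    # needed at export time.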
    elif LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
        mod_dict = {}
        for n, m in q_model.named_modules():
            mod_dict[n] = m

        class _FakeQuantize(nn.Module):
            def __init__(self, fq: torch_q.FakeQuantize) -> None:
                super().__init__()
                self.fake_quant_enabled = fq.fake_quant_enabled
                self.scale = fq.scale
                self.zero_point = fq.zero_point
                self.quant_min = fq.quant_min
                self.quant_max = fq.quant_max
                self.is_per_channel = fq.is_per_channel
                self.ch_axis = fq.ch_axis

            def forward(self, X):
                if self.fake_quant_enabled[0] == 1:
                    if self.is_per_channel:
                        X = torch.fake_quantize_per_channel_affine(
                            X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max
                        )
                    else:
                        X = torch.fake_quantize_per_tensor_affine(
                            X, self.scale, self.zero_point, self.quant_min, self.quant_max
                        )
                return X

        action_list = []
        for n, m in q_model.named_modules():
            if isinstance(m, torch_q.FakeQuantize):
                if '.' in n:
                    n_split = n.split('.')
                    prop = n_split[-1]
                    pm_name = '.'.join(n_split[:-1])
                    pm = mod_dict[pm_name]
                else:
                    prop = n
                    pm = q_model

                new_m = _FakeQuantize(m)
                action_list.append((pm, prop, new_m))

        for parent, prop, new_mod in action_list:
            setattr(parent, prop, new_mod)
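
    # Fused QAT ConvBn* modules are given a new forward so the traced graph
    # contains plain Conv + BatchNorm (+ ReLU) ops: the conv consumes the
    # fake-quantized, BN-scaled weight (w * bn.weight / sqrt(running_var + eps)),
    # and the BN statistics are adjusted so the eval-mode output matches the
    # batch-norm-folded convolution.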
    for n, m in q_model.named_modules():

        def conv_fused_wrapper(mod, scale_factor):
            type_name = type(mod).__name__

            def new_wrapper(input):
                weight_shape = [1] * len(mod.weight.shape)
                weight_shape[0] = -1
                y = type(mod)._conv_forward(
                    mod, input, mod.weight_fake_quant(mod.weight * scale_factor.reshape(weight_shape)), mod.bias
                )
                if 'Bn' in type_name:
                    y = mod.bn(y)
                if 'ReLU' in type_name:
                    y = torch.nn.functional.relu(y)
                return y

            return new_wrapper

        type_name = type(m).__name__
        if type_name in ('ConvBn1d', 'ConvBnReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d'):
            running_std = torch.sqrt(m.bn.running_var + m.bn.eps)
            scale_factor = m.bn.weight / running_std
            if m.bias is not None:
                m.bn.running_mean -= m.bias
                m.bias = None
            m.bn.running_mean *= scale_factor
            m.bn.weight /= scale_factor

            m.forward = conv_fused_wrapper(m, scale_factor)
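
    # Hook factories for the swap-node handling below: get_pre_hook wraps
    # selected positional inputs of an nn.Module with an activation fake-quant,
    # while get_hook_func does the same for FloatFunctional methods (indexing
    # into the tensor list when the wrapped method is `cat`).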
    def get_pre_hook(acps, indices):
        def pre_hook(module, input):
            new_input = list(input)
            for acp, idx in zip(acps, indices):
                new_input[idx] = acp(new_input[idx])
            return tuple(new_input)

        return pre_hook

    def get_hook_func(acps, indices, orig_func):
        def pre_hook(*input, **kwargs):
            new_input = list(input)
            for acp, idx in zip(acps, indices):
                if orig_func.__name__ == 'cat':
                    new_input[0][idx] = acp(new_input[0][idx])
                else:
                    new_input[idx] = acp(new_input[idx])
            input = tuple(new_input)
            return orig_func(*input, **kwargs)

        return pre_hook
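
    # For the ONNX backend, self.swap_nodes holds (start node, end node, input
    # index) triples; the start node's activation fake-quant is re-applied to
    # the matching input of the end node, either by patching the FloatFunctional
    # methods or by registering a forward pre-hook on the end module.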
    if self.backend == 'onnx':
        end_mod_dict = {}
        for start_mod, end_mod, idx in self.swap_nodes:
            end_mod_dict.setdefault(end_mod, ([], []))
            end_mod_dict[end_mod][0].append(start_mod)
            end_mod_dict[end_mod][1].append(idx)

        for end_mod, (start_mods, indices) in end_mod_dict.items():
            acps = []
            for start_mod in start_mods:
                if inspect.isroutine(start_mod) and isinstance(start_mod.__self__, nnq.FloatFunctional):
                    acp = start_mod.__self__.activation_post_process
                else:
                    acp = start_mod.activation_post_process
                acps.append(acp)

            if inspect.isroutine(end_mod) and isinstance(end_mod.__self__, nnq.FloatFunctional):
                ff = end_mod.__self__
                for i in range(len(indices)):
                    indices[i] -= 1

                ff.cat = get_hook_func(acps, indices, ff.cat)
                ff.add = get_hook_func(acps, indices, ff.add)
                ff.mul = get_hook_func(acps, indices, ff.mul)
                ff.add_scalar = get_hook_func(acps, indices, ff.add_scalar)
                ff.mul_scalar = get_hook_func(acps, indices, ff.mul_scalar)
                ff.add_relu = get_hook_func(acps, indices, ff.add_relu)
            else:
                assert isinstance(
                    end_mod, nn.Module
                ), "Only end nodes with `nn.Module` are supported during module swapping"
                end_mod.register_forward_pre_hook(get_pre_hook(acps, indices))
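
    # Freeze the collected ranges so scale / zero_point stay fixed from here on.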
    q_model.apply(torch_q.disable_observer)
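
# Rough usage sketch (hedged: the exact QATQuantizer arguments, the way the
# 'onnx' backend is selected, and the required ONNX opset may differ between
# TinyNeuralNetwork versions):
#
#   quantizer = QATQuantizer(model, dummy_input, work_dir='out', config={'backend': 'onnx'})
#   q_model = quantizer.quantize()
#   ...  # QAT training / PTQ calibration
#   quantizer.prepare_onnx_export(q_model)
#   torch.onnx.export(q_model, dummy_input, 'q_model.onnx', opset_version=13)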