# tinynn/graph/quantization/utils.py

import torch


def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    """Fold a BatchNorm that follows a convolution into the conv's weight and bias."""
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # Transposed convs store weights as (in, out, ...), so the per-channel BN
    # scale broadcasts over dim 1; regular convs use (out, in, ...) and dim 0.
    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

    fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
    fused_conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(
        fused_conv_b, conv_b.requires_grad
    )


def fuse_bn_conv_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold a BatchNorm that precedes a convolution into the conv's weight and bias."""
    if conv_b is None:
        conv_b = torch.zeros(conv_w.shape[0], dtype=conv_w.dtype, device=conv_w.device)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # A BN in front of the conv rescales the conv's input channels (dim 1).
    shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    reduced_dims = [i for i in range(len(conv_w.shape)) if i > 1]

    # Push BN's constant shift through the convolution to obtain the bias offset.
    fused_b = bn_b - bn_rm * bn_var_rsqrt * bn_w
    if conv_w.shape[1] == 1 and bn_rm.shape[0] > 1:
        # Depthwise/grouped case: one input channel per group, so the shift
        # maps to output channels element-wise instead of via a matmul.
        offset_b = (conv_w.sum(dim=reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
    else:
        offset_b = conv_w.sum(dim=reduced_dims).matmul(fused_b.reshape(-1, 1)).reshape(-1)

    fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
    fused_conv_b = conv_b + offset_b

    return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(
        fused_conv_b, conv_b.requires_grad
    )


def get_parameter(mod: torch.nn.Module, param: str) -> torch.nn.Parameter:
    """Fetch `weight`/`bias` uniformly across plain, fused, QAT and quantized modules."""
    if param == 'weight':
        if isinstance(mod, torch.nn.Sequential):
            # Fused modules wrap the conv/linear as their first child.
            return getattr(mod[0], param)
        elif hasattr(mod, 'weight_fake_quant'):
            # QAT modules: return the fake-quantized view of the weight.
            return mod.weight_fake_quant(getattr(mod, param))
        elif hasattr(mod, 'set_weight_bias'):
            # Quantized modules expose the weight through an accessor method.
            return getattr(mod, param)()
        else:
            return getattr(mod, param)
    elif param == 'bias':
        if isinstance(mod, torch.nn.Sequential):
            return getattr(mod[0], param)
        elif hasattr(mod, 'weight_fake_quant'):
            return getattr(mod, param)
        elif hasattr(mod, 'set_weight_bias'):
            return getattr(mod, param)()
        else:
            return getattr(mod, param)
    else:
        return getattr(mod, param)


def clamp_with_fusion(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
    """Clamp only float tensors; quantized tensors already carry their range."""
    if not x.is_quantized:
        return torch.clamp(x, min_val, max_val)
    return x


def clamp_with_fusion_(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
    """In-place variant of `clamp_with_fusion`."""
    if not x.is_quantized:
        return torch.clamp_(x, min_val, max_val)
    return x


def fake_quantize(tensor, asym, eps, quant_max, quant_min):
    """Fake-quantize `tensor` per-tensor, symmetrically or asymmetrically."""
    min_val, max_val = torch.aminmax(tensor)
    device = tensor.device
    zero_point = torch.zeros(min_val.size(), dtype=torch.int64, device=device)
    if not asym:
        # Symmetric: scale covers the larger of |min| and |max|; zero point stays 0.
        max_val_pos = torch.max(-min_val, max_val)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        # Keep eps on the same device as the data to avoid cross-device ops.
        scale = torch.max(scale, torch.tensor(eps, device=device))
    else:
        # Asymmetric: scale spans the full observed range; the zero point shifts it.
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.max(scale, torch.tensor(eps, device=device))
        zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)
    # do fake quantize
    return torch.fake_quantize_per_tensor_affine(tensor, scale, zero_point, quant_min, quant_max)
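

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original file): a minimal
# eval-mode check that fuse_conv_bn_weights reproduces Conv2d followed by
# BatchNorm2d. The layer sizes, seed and tolerance are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)

    conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    bn = torch.nn.BatchNorm2d(8).eval()
    # Use non-trivial running statistics so the check is not vacuous.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)

    fused = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    fused.weight, fused.bias = fuse_conv_bn_weights(
        conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
    )

    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5), 'conv-bn fusion mismatch'
    print('fuse_conv_bn_weights check passed')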
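

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original file):
# fake_quantize in asymmetric int8 mode. Every output value should land on
# the uniform grid scale * (q - zero_point) for integer q in
# [quant_min, quant_max]. The eps and range values are typical assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)

    t = torch.randn(4, 4) * 3.0
    fq = fake_quantize(t, asym=True, eps=1e-12, quant_max=127, quant_min=-128)

    # Recover the grid step implied by the tensor's min/max range and check
    # that the fake-quantized values divide into integer multiples of it (the
    # zero-point offset is itself an integer number of steps).
    min_val, max_val = torch.aminmax(t)
    scale = (max_val - min_val) / 255.0
    q = fq / scale
    assert torch.allclose(q, torch.round(q), atol=1e-4), 'values are off the quant grid'
    print('fake_quantize asym check passed, scale =', float(scale))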