tinynn/graph/quantization/qat_modules.py (572 lines of code) (raw):

import math
from distutils.version import LooseVersion
from typing import TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.intrinsic import ConvReLU1d
from torch.nn.modules.utils import _pair, _single

from . import fused_modules as fm
from .utils import fuse_conv_bn_weights


class Conv1d(nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv1d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv1d#torch.nn.Conv1d
    for documentation.

    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Conv1d

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode='zeros',
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        # `device`/`dtype` are only forwarded on torch >= 1.9.0; on older
        # versions they are dropped.
        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            factory_kwargs = {'device': device, 'dtype': dtype}
        else:
            factory_kwargs = {}

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig

        # On torch < 1.7.0 the activation observer is created here.
        if LooseVersion(torch.__version__) < LooseVersion('1.7.0'):
            self.activation_post_process = qconfig.activation()

        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
        else:
            self.weight_fake_quant = qconfig.weight()

    def _conv_forward(self, input, weight, bias):
        """Run the 1d convolution with an explicit `weight` and `bias`.

        On torch < 1.8.0 the convolution is computed locally; otherwise the
        call is delegated to the parent implementation.
        """
        if LooseVersion(torch.__version__) < LooseVersion('1.8.0'):
            if self.padding_mode != 'zeros':
                # Padding is applied explicitly through F.pad, so the
                # convolution itself runs with zero padding.
                # `_single` lives in torch.nn.modules.utils (not torch.nn.utils).
                return F.conv1d(
                    F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                    weight,
                    bias,
                    self.stride,
                    _single(0),
                    self.dilation,
                    self.groups,
                )
            return F.conv1d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
        else:
            return super()._conv_forward(input, weight, bias)

    def forward(self, input):
        """Convolve `input` with the fake-quantized weight."""
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.quantization utilities
        or directly from user
        """
        assert type(mod) is cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        if type(mod) is ConvReLU1d:
            mod = mod[0]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        # Share (not copy) the parameters of the float module.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Return a plain `torch.nn.Conv1d` carrying detached copies of the parameters."""
        conv = torch.nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        return conv


class ConvTranspose1d(nn.ConvTranspose1d):
    r"""
    A ConvTranspose1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.ConvTranspose1d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=convtranspose1d#torch.nn.ConvTranspose1d
    for documentation.

    Similar to `torch.nn.ConvTranspose1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups: int = 1,
        bias: bool = True,
        dilation=1,
        padding_mode: str = 'zeros',
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        # `device`/`dtype` are only forwarded on torch >= 1.9.0.
        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            factory_kwargs = {'device': device, 'dtype': dtype}
        else:
            factory_kwargs = {}

        super(ConvTranspose1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig

        if LooseVersion(torch.__version__) < LooseVersion('1.7.0'):
            self.activation_post_process = qconfig.activation()

        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
        else:
            self.weight_fake_quant = qconfig.weight()

    def forward(self, input, output_size=None):
        """Apply the transposed 1d convolution using the fake-quantized weight.

        Raises:
            ValueError: if `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation
        )
        return F.conv_transpose1d(
            input,
            self.weight_fake_quant(self.weight),
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.quantization utilities
        or directly from user
        """
        assert type(mod) is cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            output_padding=mod.output_padding,
            groups=mod.groups,
            bias=mod.bias is not None,
            dilation=mod.dilation,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        # Share (not copy) the parameters of the float module.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv


class ConvTranspose2d(nn.ConvTranspose2d):
    r"""
    A ConvTranspose2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.ConvTranspose2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=convtranspose2d#torch.nn.ConvTranspose2d
    for documentation.

    Similar to `torch.nn.ConvTranspose2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups: int = 1,
        bias: bool = True,
        dilation=1,
        padding_mode: str = 'zeros',
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        # `device`/`dtype` are only forwarded on torch >= 1.9.0.
        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            factory_kwargs = {'device': device, 'dtype': dtype}
        else:
            factory_kwargs = {}

        super(ConvTranspose2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig

        if LooseVersion(torch.__version__) < LooseVersion('1.7.0'):
            self.activation_post_process = qconfig.activation()

        if LooseVersion(torch.__version__) >= LooseVersion('1.9.0'):
            self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
        else:
            self.weight_fake_quant = qconfig.weight()

    def forward(self, input, output_size=None):
        """Apply the transposed 2d convolution using the fake-quantized weight.

        Raises:
            ValueError: if `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            # The message previously named ConvTranspose1d (copy-paste slip).
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation
        )
        return F.conv_transpose2d(
            input,
            self.weight_fake_quant(self.weight),
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.quantization utilities
        or directly from user
        """
        assert type(mod) is cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            output_padding=mod.output_padding,
            groups=mod.groups,
            bias=mod.bias is not None,
            dilation=mod.dilation,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        # Share (not copy) the parameters of the float module.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv


# Maps the spatial dimensionality to the matching BatchNorm class.
_BN_CLASS_MAP = {
    1: nn.BatchNorm1d,
    2: nn.BatchNorm2d,
    3: nn.BatchNorm3d,
}


MOD = TypeVar('MOD', bound=nn.modules.conv._ConvTransposeNd)


class _ConvTransposeBnNd(nn.modules.conv._ConvTransposeNd, fm._FusedModule):
    """Base class for QAT modules fusing a transposed convolution with BatchNorm.

    The module owns the convolution weight/bias, a `bn` BatchNorm submodule and
    a `weight_fake_quant` module that is applied to the BN-scaled weight.
    """

    _version = 2
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        # BatchNormNd args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
        dim=2,
    ):
        # The parent is always built without a bias; the bias parameter is
        # (re)created below so that it is owned by this module.
        nn.modules.conv._ConvTransposeNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            False,
            padding_mode,
        )
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        """Reset the running statistics of the BatchNorm submodule."""
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        """Re-initialize the BatchNorm state and, if present, the conv bias."""
        self.bn.reset_running_stats()
        torch.nn.init.uniform_(self.bn.weight)
        torch.nn.init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvTransposeBnNd, self).reset_parameters()

    def update_bn_stats(self):
        """Unfreeze BN: put the BatchNorm submodule in training mode. Returns self."""
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        """Freeze BN: put the BatchNorm submodule in eval mode. Returns self."""
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        """Fake-quantize the BN-scaled weight, convolve, undo the scale and apply BN."""
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # The per-channel scale is broadcast along dim 1 of the weight, i.e.
        # the transposed-conv layout where dim 1 holds the output channels.
        # NOTE(review): this presumably only broadcasts for groups == 1 - confirm.
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[1] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvTransposeBnNd, self).extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Load state, renaming v1 BN keys to their v2 (`bn.*`) names first."""
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvTransposeBnNd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization utilities
        or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) is cls._FLOAT_MODULE, (
            'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        )  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.output_padding,
            conv.groups,
            conv.bias is not None,
            conv.dilation,
            conv.padding_mode,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        # Share (not copy) the parameters and buffers of the float modules.
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        """Return a float module with BN folded into the convolution weights.

        When the subclass defines a ReLU module, the convolution is wrapped
        together with it in the subclass' fused float module.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.output_padding,
            self.groups,
            self.bias is not None,
            self.dilation,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
                True,
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv


class ConvTransposeBn2d(_ConvTransposeBnNd, nn.ConvTranspose2d):
    r"""
    A ConvTransposeBn2d module is a module fused from ConvTranspose2d and BatchNorm2d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.ConvTranspose2d` and
    :class:`torch.nn.BatchNorm2d`.

    Similar to :class:`torch.nn.ConvTranspose2d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """

    _FLOAT_MODULE = fm.ConvTransposeBn2d
    _FLOAT_CONV_MODULE = nn.ConvTranspose2d
    _FLOAT_BN_MODULE = nn.BatchNorm2d
    _FLOAT_RELU_MODULE = None

    def __init__(
        self,
        # ConvTransposeNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        groups=1,
        bias=None,
        dilation=1,
        padding_mode='zeros',
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Normalized like the other geometry arguments, matching what
        # nn.ConvTranspose2d does in its own constructor.
        output_padding = _pair(output_padding)
        # `transposed` must be True so the weight is allocated in the
        # transposed layout (in_channels, out_channels // groups, kH, kW)
        # that `_forward` assumes when scaling along dim 1.
        _ConvTransposeBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=2,
        )

    def _conv_forward(self, input, weight, bias, output_size=None):
        """Run the transposed 2d convolution with an explicit `weight` and `bias`.

        Raises:
            ValueError: if `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for _ConvTransposeNd')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation
        )  # type: ignore[arg-type]

        return F.conv_transpose2d(
            input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation
        )

    @classmethod
    def transform(cls, mod):
        """Convert a fused conv-transpose+BN QAT module into a QAT `ConvTranspose2d`.

        The BN statistics and affine parameters of `mod.bn` are folded into the
        convolution weight/bias, and `mod`'s `weight_fake_quant` and
        `activation_post_process` are reused on the returned module.
        NOTE(review): `mod.activation_post_process` is not set by this file on
        torch >= 1.7.0; presumably it is attached by the caller - confirm.
        """
        conv = ConvTranspose2d(
            in_channels=mod.in_channels,
            out_channels=mod.out_channels,
            kernel_size=mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            output_padding=mod.output_padding,
            groups=mod.groups,
            bias=mod.bias is not None,
            dilation=mod.dilation,
            padding_mode=mod.padding_mode,
            qconfig=mod.qconfig,
        )

        conv.weight, conv.bias = fuse_conv_bn_weights(
            mod.weight,
            mod.bias,
            mod.bn.running_mean,
            mod.bn.running_var,
            mod.bn.eps,
            mod.bn.weight,
            mod.bn.bias,
            transpose=True,
        )

        conv.weight_fake_quant = mod.weight_fake_quant
        conv.activation_post_process = mod.activation_post_process

        return conv