tinynn/graph/quantization/fused_modules.py (103 lines of code) (raw):

import copy
import re

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni

from .utils import fuse_conv_bn_weights, fuse_bn_conv_weights

# Base class for fused containers; fall back to ``nn.Sequential`` when this
# PyTorch build does not expose ``nni._FusedModule``.
_FusedModule = getattr(nni, '_FusedModule', nn.Sequential)


class ConvTransposeBn2d(_FusedModule):
    """Container holding an ``nn.ConvTranspose2d`` followed by an ``nn.BatchNorm2d``.

    Returned by ``fuse_convtranspose_bn`` on the QAT path.
    """

    def __init__(self, conv, bn):
        assert (
            type(conv) is nn.ConvTranspose2d and type(bn) is nn.BatchNorm2d
        ), 'Incorrect types for input modules{}{}'.format(type(conv), type(bn))
        super(ConvTransposeBn2d, self).__init__(conv, bn)


def _version_tuple(version):
    """Return the leading ``(major, minor, patch)`` integers of a version string.

    Anything after the numeric prefix (pre-release/dev suffixes, ``+local`` tags) is
    ignored, e.g. ``'1.11.0a0+git1234'`` -> ``(1, 11, 0)``. Missing minor/patch parts
    count as ``0``; a string with no leading number yields ``()``.
    """
    match = re.match(r'\s*(\d+)(?:\.(\d+))?(?:\.(\d+))?', str(version))
    if match is None:
        return ()
    return tuple(int(group) if group is not None else 0 for group in match.groups())


# True on PyTorch >= 1.11. The wrapper below then expects the original
# ``fuse_known_modules`` to receive ``is_qat``, and ``fuse_convtranspose_bn`` calls
# ``fuse_conv_bn_eval(..., transpose=True)``.
# (Replaces ``distutils.version.LooseVersion``: distutils is removed in Python 3.12.)
HAS_QAT_IN_FUNC = _version_tuple(torch.__version__) >= (1, 11, 0)


def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Given ConvTranspose and bn modules, fuses them and returns the fused module

    Args:
        is_qat: whether fusion is for quantization-aware training. ``None`` means
            "use ``convt.training``".
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the ConvTranspose layer.
            batch norm N should match the ConvTranspose N

    Returns:
        In QAT mode, a ``ConvTransposeBn2d`` wrapping ``convt`` and ``bn``. Otherwise,
        a copy of ``convt`` with ``bn`` folded into its weight and bias.

    Raises:
        AssertionError: if ``convt`` and ``bn`` are not in the same train/eval mode.

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3).eval()
        >>> b1 = nn.BatchNorm2d(20).eval()
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(False, m1, b1)
    """
    assert convt.training == bn.training, "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat is None:
        is_qat = convt.training

    if is_qat:
        return ConvTransposeBn2d(convt, bn)

    if HAS_QAT_IN_FUNC:
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)

    # Older PyTorch: fold BN into a copy by hand. The trailing ``True`` mirrors
    # ``transpose=True`` above (presumably the helper's transposed-weight flag).
    fused_conv = copy.deepcopy(convt)
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, True
    )
    return fused_conv


def fuse_bn_conv(is_qat, bn, conv):
    """Fold a BatchNorm that *precedes* ``conv`` into a copy of ``conv`` (eval path only).

    Args:
        is_qat: whether fusion is for QAT; ``None`` means "use ``conv.training``".
        bn: the batch-norm module applied before ``conv``.
        conv: the convolution module.

    Returns:
        A deep copy of ``conv`` whose weight and bias come from ``fuse_bn_conv_weights``.

    Raises:
        AssertionError: if the modules are in different modes, or QAT fusion is requested.
    """
    assert conv.training == bn.training, "Conv and BN both must be in the same mode (train or eval)."

    if is_qat is None:
        is_qat = conv.training
    assert not is_qat, "BN Conv fusion for QAT is not yet supported"

    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = fuse_bn_conv_weights(
        fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
    )
    return fused_conv


def fuse_bn_conv_relu(is_qat, bn, conv, relu):
    """Fold a preceding BatchNorm into ``conv`` and pair the result with ``relu``.

    Same contract as ``fuse_bn_conv``, additionally requiring ``relu`` to share the
    modules' train/eval mode.

    Returns:
        ``nni.ConvReLU2d(fused_conv, relu)``.
    """
    assert conv.training == bn.training == relu.training, "Conv, BN and ReLU must be in the same mode (train or eval)."
    return nni.ConvReLU2d(fuse_bn_conv(is_qat, bn, conv), relu)


# Module-type patterns handled by this file instead of PyTorch's default fuser.
PATCH_MOD_MAPPING = {
    (
        nn.ConvTranspose2d,
        nn.BatchNorm2d,
    ): fuse_convtranspose_bn,
    (
        nn.BatchNorm2d,
        nn.Conv2d,
    ): fuse_bn_conv,
    (
        nn.BatchNorm2d,
        nn.Conv2d,
        nn.ReLU,
    ): fuse_bn_conv_relu,
}


def gen_fuse_known_modules_wrapper(orig_fuse_known_modules):
    """Wrap ``orig_fuse_known_modules`` so patterns in ``PATCH_MOD_MAPPING`` are fused here.

    Module lists whose (pre-parametrization) types are not in the mapping are forwarded
    to ``orig_fuse_known_modules`` with the original arguments.
    """

    def fuse_known_modules(mod_list, *args, **kwargs):
        mod_types = tuple(type_before_parametrizations(m) for m in mod_list)
        fuse_func = PATCH_MOD_MAPPING.get(mod_types)
        if fuse_func is None:
            return orig_fuse_known_modules(mod_list, *args, **kwargs)

        is_qat = None
        if HAS_QAT_IN_FUNC:
            # ``is_qat`` may arrive by keyword or as the first positional extra argument.
            # Only touch ``args[0]`` when it exists; ``None`` lets the fuser infer it.
            if 'is_qat' in kwargs:
                is_qat = kwargs['is_qat']
            elif args:
                is_qat = args[0]

        fused = fuse_func(is_qat, *mod_list)

        # NOTE: forward hooks not processed in the two following loops will be lost after the fusion.
        # Iterate over snapshots and clear afterwards: deleting from the hook dicts while
        # iterating them raises ``RuntimeError``.
        base_mod, last_mod = mod_list[0], mod_list[-1]

        # Move pre forward hooks of the base module to resulting fused module
        for pre_hook_fn in list(base_mod._forward_pre_hooks.values()):
            fused.register_forward_pre_hook(pre_hook_fn)
        base_mod._forward_pre_hooks.clear()

        # Move post forward hooks of the last module to resulting fused module
        for hook_fn in list(last_mod._forward_hooks.values()):
            fused.register_forward_hook(hook_fn)
        last_mod._forward_hooks.clear()

        # Keep the list length: slot 0 gets the fused module, the remaining slots become
        # identities in the same train/eval mode as the base module.
        new_mod = [fused]
        for _ in range(1, len(mod_list)):
            identity = nn.Identity()
            identity.training = base_mod.training
            new_mod.append(identity)
        return new_mod

    return fuse_known_modules


def type_before_parametrizations(module):
    """Return the type of ``module`` ignoring any registered parametrizations.

    Uses ``torch.nn.utils.parametrize.type_before_parametrizations`` when this PyTorch
    provides it. Otherwise, for a module reported as parametrized, returns the first
    base class of ``module.__class__``. In all other cases returns ``type(module)``.
    """
    has_parametrize = hasattr(torch.nn.utils, 'parametrize')
    has_is_parametrized = has_parametrize and hasattr(torch.nn.utils.parametrize, 'is_parametrized')
    has_type_before_parametrizations = has_parametrize and hasattr(
        torch.nn.utils.parametrize, 'type_before_parametrizations'
    )

    if has_type_before_parametrizations:
        return torch.nn.utils.parametrize.type_before_parametrizations(module)
    if has_is_parametrized and torch.nn.utils.parametrize.is_parametrized(module):
        return module.__class__.__bases__[0]
    return type(module)