def prepare_onnx_export()

in tinynn/graph/quantization/quantizer.py [0:0]


    def prepare_onnx_export(self, q_model):
        """Prepares a QAT/PTQ-prepared model for ONNX export (in place)

        ``q_model`` is modified directly; nothing is returned. In order, this:

        1. (``onnx`` backend) disables fake quantization on the
           ``activation_post_process`` of each node in ``self.leaf_nodes``;
        2. switches every fake-quantize module to a signed 8-bit range
           (``quant_min=-128``, ``quant_max=127``, ``dtype=torch.int8``),
           forcing symmetric zero points to 0 and converting fixed-qparam
           ``per_tensor_affine`` fake quantizers to ``per_tensor_symmetric``;
        3. routes ``FusedMovingAvgObsFakeQuantize`` modules through
           ``FakeQuantize.forward`` (PyTorch >= 1.12), or replaces every
           ``FakeQuantize`` with a minimal fake-quantize module
           (1.10 <= PyTorch < 1.12);
        4. rewrites ``ConvBn*`` / ``ConvBnReLU*`` QAT modules so the conv bias
           is folded into ``bn.running_mean`` and the BN-scaled weight is
           fake-quantized directly;
        5. (``onnx`` backend) applies the activation post-processes of the
           start nodes in ``self.swap_nodes`` to the inputs of their end nodes,
           via forward pre-hooks or by wrapping ``FloatFunctional`` methods;
        6. disables all observers.

        Args:
            q_model (nn.Module): The QAT/PTQ-prepared model

        """

        if self.backend == 'onnx':
            for mod in self.leaf_nodes:
                torch_q.disable_fake_quant(mod.activation_post_process)

        fq_base_cls = getattr(torch_q, 'FakeQuantizeBase', torch_q.FakeQuantize)
        # The updates below are in-place writes to qparam tensors, which may be
        # autograd leaves that require grad. Doing them without autograd keeps
        # this method usable outside of `torch.no_grad()`.
        with torch.no_grad():
            for m in q_model.modules():
                if not isinstance(m, fq_base_cls):
                    continue

                if m.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
                    m.zero_point.fill_(0)
                else:
                    # Only fixed-qparam affine fake quantizers are expected here
                    assert m.qscheme == torch.per_tensor_affine
                    assert isinstance(m, torch_q.FixedQParamsFakeQuantize)

                    m.qscheme = torch.per_tensor_symmetric
                    # NOTE(review): assumes the fixed qparams target an unsigned
                    # [0, 255] range. zp=128 maps to zp=0 at the same scale, while
                    # zp=0 needs a doubled scale to keep roughly the same positive range.
                    if torch.all(m.zero_point == 128):
                        m.zero_point.fill_(0)
                    else:
                        assert not torch.is_nonzero(m.zero_point)
                        m.scale *= 2

                m.quant_min = -128
                m.quant_max = 127
                m.dtype = torch.int8

                if hasattr(m, 'activation_post_process'):
                    m.activation_post_process.quant_min = -128
                    m.activation_post_process.quant_max = 127
                    m.activation_post_process.dtype = torch.int8

        torch_version = LooseVersion(torch.__version__)
        if torch_version >= LooseVersion('1.12.0'):
            # Run fused fake quantizers through the unfused `FakeQuantize.forward`
            # (presumably because the fused op does not export cleanly to ONNX)

            def get_wrapper(mod):
                def wrapper(*args, **kwargs):
                    return torch_q.FakeQuantize.forward(mod, *args, **kwargs)

                return wrapper

            fused_fq_cls = getattr(torch_q, 'FusedMovingAvgObsFakeQuantize', None)
            if fused_fq_cls is not None:
                for m in q_model.modules():
                    if isinstance(m, fused_fq_cls):
                        m.forward = get_wrapper(m)
        elif torch_version >= LooseVersion('1.10.0'):

            class _FakeQuantize(nn.Module):
                """Minimal module that only applies fake quantization with the qparams of `fq`"""

                def __init__(self, fq: torch_q.FakeQuantize) -> None:
                    super().__init__()
                    # The tensors are shared with `fq`, so e.g. a disabled
                    # `fake_quant_enabled` flag stays disabled
                    self.fake_quant_enabled = fq.fake_quant_enabled
                    self.scale = fq.scale
                    self.zero_point = fq.zero_point
                    self.quant_min = fq.quant_min
                    self.quant_max = fq.quant_max
                    self.is_per_channel = fq.is_per_channel
                    self.ch_axis = fq.ch_axis

                def forward(self, X):
                    if self.fake_quant_enabled[0] == 1:
                        if self.is_per_channel:
                            X = torch.fake_quantize_per_channel_affine(
                                X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max
                            )
                        else:
                            X = torch.fake_quantize_per_tensor_affine(
                                X, self.scale, self.zero_point, self.quant_min, self.quant_max
                            )
                    return X

            mod_dict = dict(q_model.named_modules())
            action_list = []
            for n, m in q_model.named_modules():
                if isinstance(m, torch_q.FakeQuantize):
                    # The root is registered under '', so top-level children resolve to q_model
                    pm_name, _, prop = n.rpartition('.')
                    action_list.append((mod_dict[pm_name], prop, _FakeQuantize(m)))

            # Swap only after the traversal so the module tree is not mutated while iterating
            for parent, prop, new_mod in action_list:
                setattr(parent, prop, new_mod)

        def conv_fused_wrapper(mod, scale_factor):
            """Builds a forward for `mod` that fake-quantizes `weight * scale_factor` directly"""
            type_name = type(mod).__name__
            with_bn = 'Bn' in type_name
            with_relu = 'ReLU' in type_name
            # Per-output-channel factor, broadcastable against the weight: (-1, 1, 1, ...)
            weight_shape = [-1] + [1] * (len(mod.weight.shape) - 1)

            def new_wrapper(input):
                weight = mod.weight_fake_quant(mod.weight * scale_factor.reshape(weight_shape))
                y = type(mod)._conv_forward(mod, input, weight, mod.bias)
                if with_bn:
                    y = mod.bn(y)
                if with_relu:
                    y = torch.nn.functional.relu(y)
                return y

            return new_wrapper

        conv_bn_types = ('ConvBn1d', 'ConvBnReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d')
        # In-place updates of BN parameters/buffers and the conv bias must not be tracked by autograd
        with torch.no_grad():
            for m in q_model.modules():
                if type(m).__name__ not in conv_bn_types:
                    continue
                # With s = bn.weight / sqrt(running_var + eps), fold the conv bias b into
                # running_mean and rescale running_mean / bn.weight, so that the new
                # bn'(conv(x, fq(w * s))) equals bn(conv(x, fq(w * s)) / s + b) computed
                # with the original BN parameters, when BN uses its running statistics.
                running_std = torch.sqrt(m.bn.running_var + m.bn.eps)
                scale_factor = m.bn.weight / running_std
                if m.bias is not None:
                    m.bn.running_mean -= m.bias
                    m.bias = None
                m.bn.running_mean *= scale_factor
                m.bn.weight /= scale_factor
                m.forward = conv_fused_wrapper(m, scale_factor)

        def get_pre_hook(acps, indices):
            """Forward pre-hook applying `acps[i]` to positional input `indices[i]`"""

            def pre_hook(module, input):
                new_input = list(input)
                for acp, idx in zip(acps, indices):
                    new_input[idx] = acp(new_input[idx])
                return tuple(new_input)

            return pre_hook

        def get_hook_func(acps, indices, orig_func):
            """Wraps `orig_func` so `acps[i]` is applied to argument `indices[i]` first

            For `cat`, the indices address tensors inside its first (sequence) argument.
            """
            is_cat = orig_func.__name__ == 'cat'

            def hooked_func(*input, **kwargs):
                new_input = list(input)
                if is_cat:
                    # Copy the tensor sequence: a tuple argument must work, and the
                    # caller's own list must not be modified in place
                    new_input[0] = list(new_input[0])
                    targets = new_input[0]
                else:
                    targets = new_input
                for acp, idx in zip(acps, indices):
                    targets[idx] = acp(targets[idx])
                return orig_func(*new_input, **kwargs)

            # Keep the original name so that wrapping this method again still detects `cat`
            hooked_func.__name__ = orig_func.__name__
            return hooked_func

        def float_functional_of(fn):
            """Returns the `FloatFunctional` that `fn` is a bound method of, otherwise None"""
            if inspect.isroutine(fn):
                owner = getattr(fn, '__self__', None)
                if isinstance(owner, nnq.FloatFunctional):
                    return owner
            return None

        if self.backend == 'onnx':
            # Group the swap records by end node: end_mod -> ([start_mod, ...], [idx, ...])
            end_mod_dict = {}
            for start_mod, end_mod, idx in self.swap_nodes:
                start_mods, indices = end_mod_dict.setdefault(end_mod, ([], []))
                start_mods.append(start_mod)
                indices.append(idx)

            for end_mod, (start_mods, indices) in end_mod_dict.items():
                acps = []
                for start_mod in start_mods:
                    ff = float_functional_of(start_mod)
                    owner = start_mod if ff is None else ff
                    acps.append(owner.activation_post_process)

                ff = float_functional_of(end_mod)
                if ff is not None:
                    # Recorded indices are shifted down by one before indexing the
                    # method's own arguments
                    indices = [i - 1 for i in indices]
                    for op_name in ('cat', 'add', 'mul', 'add_scalar', 'mul_scalar', 'add_relu'):
                        setattr(ff, op_name, get_hook_func(acps, indices, getattr(ff, op_name)))
                else:
                    assert isinstance(
                        end_mod, nn.Module
                    ), "Only end nodes with `nn.Module` are supported during module swapping"

                    end_mod.register_forward_pre_hook(get_pre_hook(acps, indices))

        q_model.apply(torch_q.disable_observer)