def error_analysis()

in tinynn/graph/quantization/quantizer.py [0:0]


    def error_analysis(self, qat_model: nn.Module, dummy_input, threshold: float = 20.0):
        """Generates the QAT error report using the SQNR metric

        The float reference is obtained by running ``dummy_input`` through the model after
        ``torch.quantization.disable_fake_quant`` / ``disable_observer`` have been applied. The input
        captured at every ``torch.quantization.FakeQuantize`` module is then fed through that module
        again with fake quantization restored. The SQNR (in dB) between the float input and the
        fake-quantized output is computed for each module. Entries with an SQNR at or below
        ``threshold`` are reported through ``log.warning``, lowest first. They are split into
        weights (names ending in ``.weight_fake_quant``) and activations.

        The fake-quant/observer flags, forward hooks and training mode of the model are restored
        before returning, also when an exception is raised.

        Args:
            qat_model: The QAT model (may be wrapped in ``DataParallel`` / ``DistributedDataParallel``)
            dummy_input: A viable input to the model, either a tensor or a tuple/list of positional
                arguments
            threshold (float): The threshold of SQNR. Defaults to 20.0

        Raises:
            AssertionError: If ``dummy_input`` is neither a tensor nor a tuple/list
        """

        if isinstance(qat_model, (DataParallel, DistributedDataParallel)):
            model = qat_model.module
        else:
            model = qat_model

        modules_list = {}  # qualified name -> FakeQuantize module to analyze
        names_list = {}  # FakeQuantize module -> qualified name

        # Snapshot the flags of every module that carries them, not only ``FakeQuantize``.
        # ``model.apply(disable_fake_quant / disable_observer)`` below may also touch other
        # fake-quantize implementations, and all of them must be restored afterwards.
        fake_quant_enabled_dict = {}
        observer_enabled_dict = {}
        for n, m in model.named_modules():
            fq_flag = getattr(m, 'fake_quant_enabled', None)
            if isinstance(fq_flag, torch.Tensor):
                fake_quant_enabled_dict[m] = fq_flag.clone()
            obs_flag = getattr(m, 'observer_enabled', None)
            if isinstance(obs_flag, torch.Tensor):
                observer_enabled_dict[m] = obs_flag.clone()
            if isinstance(m, torch.quantization.FakeQuantize):
                names_list[m] = n
                modules_list[n] = m

        if not modules_list:
            log.warning('No FakeQuantize modules found. Are you sure you passed in a QAT model?')
            return

        # Validate and prepare the input before any model state is modified. This way a bad input
        # cannot leave the model with hooks attached or fake quantization disabled.
        if isinstance(dummy_input, torch.Tensor):
            actual_input = [dummy_input]
        elif isinstance(dummy_input, (tuple, list)):
            actual_input = list(dummy_input)
        else:
            log.error(f'Unsupported type {type(dummy_input)} for dummy input')
            # An explicit raise (rather than ``assert False``) is not stripped under ``python -O``.
            raise AssertionError(f'Unsupported type {type(dummy_input)} for dummy input')

        device = get_module_device(model)
        actual_input = [
            x.to(device) if isinstance(x, torch.Tensor) and x.device != device else x for x in actual_input
        ]

        float_results = {}  # qualified name -> captured float inputs of that FakeQuantize module

        def forward_hook(module, inputs, output):
            # Clone the captured inputs so that in-place ops later in the forward pass cannot
            # alter the float reference.
            float_results[names_list[module]] = tuple(
                x.detach().clone() if isinstance(x, torch.Tensor) else x for x in inputs
            )

        def sqnr(x, y):
            # SQNR in dB: 20 * log10(||x|| / ||x - y||).
            # An exact match gives +inf, and an all-zero exact match gives NaN.
            # Neither value passes the ``<= threshold`` check below.
            signal = torch.norm(x)
            noise = torch.norm(x - y)
            return (20 * torch.log10(signal / noise)).item()

        def format_qparam(t, spec=''):
            # Per-tensor qparams hold a single value. Per-channel ones hold one value per channel,
            # which ``Tensor.item()`` cannot convert, so they are summarized by their min/max.
            if t.numel() == 1:
                return format(t.item(), spec)
            values = t.flatten().tolist()
            if not values:
                return '[]'
            return f'[{format(min(values), spec)}, {format(max(values), spec)}] (per-channel min/max)'

        hooks = [m.register_forward_hook(forward_hook) for m in modules_list.values()]
        was_training = model.training

        model.apply(torch.quantization.disable_fake_quant)
        model.apply(torch.quantization.disable_observer)

        try:
            try:
                model.eval()
                with torch.no_grad():
                    model(*actual_input)
            finally:
                for h in hooks:
                    h.remove()
                # Restore each module's fake-quant flag so the SQNR pass below actually quantizes.
                # Observers stay disabled so the analysis does not update their statistics.
                # ``copy_`` keeps the registered buffer objects instead of replacing them.
                for m, v in fake_quant_enabled_dict.items():
                    m.fake_quant_enabled.copy_(v)

            q_errors_weight = []
            q_errors_activ = []
            # ``popitem`` drops each captured input as soon as it has been evaluated.
            while float_results:
                name, f = float_results.popitem()
                mod = modules_list[name]
                with torch.no_grad():
                    q = mod(*f)
                    loss = sqnr(f[0], q)
                if loss <= threshold:
                    # Report under the owning module's name (strip the fake-quant attribute name).
                    owner_name = name.rpartition('.')[0]
                    if name.endswith('.weight_fake_quant'):
                        q_errors_weight.append((owner_name, mod, loss))
                    else:
                        q_errors_activ.append((owner_name, mod, loss))

            q_errors_weight.sort(key=lambda x: x[2])
            q_errors_activ.sort(key=lambda x: x[2])

            logs = ['Quantization error report:']
            for title, errors in (('Weights', q_errors_weight), ('Activations', q_errors_activ)):
                if errors:
                    logs.append('')
                    logs.append(f'{title} (SQNR <= {threshold}):')
                    for owner_name, fq, err in errors:
                        logs.append(
                            f'{owner_name} SQNR: {err:.4f}, scale: {format_qparam(fq.scale, ".4f")}, '
                            f'zero_point: {format_qparam(fq.zero_point)}'
                        )

            if not q_errors_weight and not q_errors_activ:
                logs.append('')
                logs.append('All good!')

            logs.append('')
            log.warning('\n'.join(logs))
        finally:
            for m, v in observer_enabled_dict.items():
                m.observer_enabled.copy_(v)
            # Restore the caller's training mode instead of unconditionally switching to train.
            model.train(was_training)