def layer_error_analysis()

in tinynn/util/quantization_analysis_util.py [0:0]


def layer_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine', sort_num: int = 20):
    """Generates the layerwise quant error report using the given metric, the q_model need to be qat_prepared.

    The model is run once with fake quantization and observers disabled. Forward hooks record the
    positional inputs received by every ``torch.quantization.FakeQuantize`` module during that run.
    Each recorded input is then passed through its FakeQuantize module again with the original
    fake-quant flag restored (observers still disabled), and ``metric`` compares the recorded float
    input with that output. The ``sort_num`` weight quantizers and activation quantizers with the
    smallest metric values are reported via ``error_print``.

    The hooks, fake-quant flags, observer flags and training mode of the model are restored before
    returning, even if an exception is raised part-way through.

    Args:
        q_model: The quant prepared model. ``DataParallel``/``DistributedDataParallel`` wrappers
            are unwrapped via ``.module``.
        dummy_input: A viable input to the model: a single tensor, or a tuple/list of positional
            arguments. Tensor arguments are moved to the model's device if needed.
        metric: Metrics for measuring the error of floating point tensor and quantized tensor.
                Default to be 'cosine', optional 'sqnr'. Must be a key of ``METRIC_DICT``.
        sort_num: Number of entries (per category: weights / activations) with the smallest
            metric value to report. Float values are truncated to int. Defaults to 20

    Raises:
        AssertionError: if ``dummy_input`` is neither a tensor nor a tuple/list.
    """

    if isinstance(q_model, (DataParallel, DistributedDataParallel)):
        model = q_model.module
    else:
        model = q_model
    metric_fn = METRIC_DICT[metric]

    # Slicing requires an integer bound; accept e.g. 20.0 as well.
    if isinstance(sort_num, float):
        sort_num = int(sort_num)

    # Validate the input before touching any model state, so a bad input leaves the model intact.
    if isinstance(dummy_input, torch.Tensor):
        actual_input = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        actual_input = list(dummy_input)
    else:
        log.error(f'Unsupported type {type(dummy_input)} for dummy input')
        # Explicit raise instead of `assert False`, which is stripped under `python -O`.
        raise AssertionError(f'Unsupported type {type(dummy_input)} for dummy input')

    device = get_module_device(model)
    actual_input = [
        arg.to(device) if isinstance(arg, torch.Tensor) and arg.device != device else arg for arg in actual_input
    ]

    train_flag = model.training
    modules_list = {}  # FakeQuantize name -> module
    names_list = {}  # FakeQuantize module -> name
    float_results = {}  # FakeQuantize name -> positional inputs captured during the float pass
    fake_quant_enabled_dict = {}  # module -> snapshot of its fake_quant_enabled buffer
    observer_enabled_dict = {}  # module -> snapshot of its observer_enabled buffer
    hooks = []

    def forward_hook(module, inputs, output):
        # If a FakeQuantize module is called more than once, the last call wins.
        float_results[names_list[module]] = inputs

    try:
        model.eval()
        with torch.no_grad():
            for n, m in model.named_modules():
                if isinstance(m, torch.quantization.FakeQuantize):
                    names_list[m] = n
                    modules_list[n] = m

                    # Clone: disable_fake_quant / disable_observer modify these buffers in place.
                    fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                    observer_enabled_dict[m] = m.observer_enabled.clone()

                    hooks.append(m.register_forward_hook(forward_hook))

            if not modules_list:
                log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

            # Float pass: FakeQuantize modules act as identity and collect no statistics.
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)

            model(*actual_input)

            for h in hooks:
                h.remove()
            hooks.clear()

            # Re-enable fake quantization (observers stay disabled so statistics are not updated).
            for m, v in fake_quant_enabled_dict.items():
                m.fake_quant_enabled = v

            q_errors_weight = []
            q_errors_activ = []
            while float_results:
                n, f = float_results.popitem()
                mod = modules_list[n]
                q = mod(*f)
                loss = metric_fn(f[0], q)
                # Strip the quantizer's own attribute name to get the owning layer's name.
                actual_n = n.rpartition('.')[0]
                target = q_errors_weight if n.endswith('.weight_fake_quant') else q_errors_activ
                target.append((actual_n, mod, loss))

            q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2])[:sort_num]
            q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2])[:sort_num]

            error_print(metric, q_errors_activ, q_errors_weight, sort_num)
    finally:
        # Always hand the model back in the state we received it, even on failure.
        for h in hooks:
            h.remove()
        hooks.clear()

        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v

        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v

        if train_flag:
            model.train()