def graph_error_analysis()

in tinynn/util/quantization_analysis_util.py


def graph_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine'):
    """Generates the cumulative quant error report using the given metric, the q_model need to be qat_prepared.

    Args:
        q_model: The quant prepared model.
        dummy_input: A viable input to the model
        metric: Metrics for measuring the error of floating point tensor and quantized tensor.
                Default to be 'cosine', optional 'sqnr'.
    """
    if isinstance(q_model, (DataParallel, DistributedDataParallel)):
        model = q_model.module
    else:
        model = q_model
    metric_fn = METRIC_DICT[metric]

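    # Remember the original training mode and switch to eval for the analysis.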
    train_flag = model.training
    model.eval()

    with torch.no_grad():
        modules_list = {}
        names_list = {}
        results = {}
        hooks = []

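        # The hook records the inputs of each FakeQuantize module, keyed by module name.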
        def forward_hook(module, input, output):
            name = names_list[module]
            results[name] = input

        fake_quant_enabled_dict = {}
        observer_enabled_dict = {}
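        # Snapshot each FakeQuantize module's enabled flags and hook it to capture its inputs.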
        for n, m in model.named_modules():
            if isinstance(m, torch.quantization.FakeQuantize):
                names_list[m] = n
                modules_list[n] = m

                fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                observer_enabled_dict[m] = m.observer_enabled.clone()

                hooks.append(m.register_forward_hook(forward_hook))

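        # First pass: disable fake quantization and observers so the hooks record float activations.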
        model.apply(torch.quantization.disable_fake_quant)
        model.apply(torch.quantization.disable_observer)

        if len(modules_list) == 0:
            log.warning('No FakeQuantize modules found. Are you sure you have prepared your model?')

        device = get_module_device(model)

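        # Normalize dummy_input into a list and move every tensor to the model's device.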
        if type(dummy_input) is torch.Tensor:
            actual_input = [dummy_input]
        elif isinstance(dummy_input, (tuple, list)):
            actual_input = list(dummy_input)
        else:
            log.error(f'Unsupported type {type(dummy_input)} for dummy input')
            assert False

        for i, t in enumerate(actual_input):
            if type(t) is torch.Tensor and t.device != device:
                actual_input[i] = t.to(device)

        model(*actual_input)

        # Re-enable fake quantization and run a second pass to record the fake-quantized activations.
        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v
        float_results = results
        results = {}

        model(*actual_input)

        for h in hooks:
            h.remove()
        hooks.clear()

        # Compare the float and fake-quantized inputs recorded for each FakeQuantize module.
        q_errors_activ = []
        for name, f_tensor in float_results.items():
            assert name in results, f'{name} not in results'
            if name.endswith('.weight_fake_quant'):
                continue
            # Drop the trailing fake-quantize suffix to recover the owning layer's name.
            actual_n = '.'.join(name.split('.')[:-1])
            loss = metric_fn(f_tensor[0], results[name][0])
            q_errors_activ.append((actual_n, modules_list[name], loss))

        error_print(metric, q_errors_activ, [], '')

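        # Restore the original observer state.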
        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v

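    # Restore the original training mode.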
    if train_flag:
        model.train()
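
For context, a minimal usage sketch is shown below. The QATQuantizer import path
and arguments are assumptions based on typical TinyNeuralNetwork examples and may
differ between versions; MyModel is a hypothetical placeholder:

    import torch

    from tinynn.graph.quantization.quantizer import QATQuantizer  # assumed import path
    from tinynn.util.quantization_analysis_util import graph_error_analysis

    model = MyModel()  # hypothetical float model
    dummy_input = torch.randn(1, 3, 224, 224)

    # QAT-prepare the model so that it contains FakeQuantize modules.
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    q_model = quantizer.quantize()

    # ... run calibration / QAT training here ...

    # Print the per-layer cumulative quantization error report.
    graph_error_analysis(q_model, dummy_input, metric='cosine')

The two supported metrics roughly correspond to the sketches below; the actual
METRIC_DICT entries may differ in detail:

    import torch
    import torch.nn.functional as F

    def cosine(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Cosine similarity between the flattened float and quantized tensors.
        return F.cosine_similarity(x.reshape(-1), y.reshape(-1), dim=0)

    def sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB: signal power over error power.
        return 20 * torch.log10(x.norm() / (x - y).norm())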