tinynn/util/quantization_analysis_util.py (235 lines of code) (raw):

"""Utilities for analysing quantization error and weight distributions."""

import os
from typing import List, Optional

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel

from .train_util import get_logger, get_module_device

log = get_logger(__name__, 'INFO')


def sqnr(x: torch.Tensor, y: torch.Tensor):
    """Return the signal-to-quantization-noise ratio of ``y`` w.r.t. ``x`` in dB.

    Computed as ``20 * log10(||x|| / ||x - y||)`` and returned as a Python float.
    When ``x`` and ``y`` are identical the noise norm is zero and the result is ``inf``.
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return (20 * torch.log10(Ps / Pn)).item()


def cosine(x: torch.Tensor, y: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    """Calculate the cosine similarity between x and y.

    Both tensors are flattened from dim 1 onwards (1-D tensors are first given a
    leading dim of size 1), so the similarity is computed per entry of dim 0.

    Args:
        x: The reference tensor
        y: The tensor to compare, must have the same shape as ``x``
        reduction: 'mean', 'sum' or 'none' (case-insensitive). Defaults to 'mean'

    Raises:
        ValueError: If the shapes differ or ``reduction`` is not supported.
    """
    if x.shape != y.shape:
        raise ValueError(f'Can not compute loss for tensors with different shape. ({x.shape} and {y.shape})')

    reduction = str(reduction).lower()

    if x.ndim == 1:
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)

    x = x.flatten(start_dim=1).float()
    y = y.flatten(start_dim=1).float()

    cosine_sim = torch.cosine_similarity(x, y, dim=-1)

    if reduction == 'mean':
        return torch.mean(cosine_sim)
    elif reduction == 'sum':
        return torch.sum(cosine_sim)
    elif reduction == 'none':
        return cosine_sim
    else:
        raise ValueError(f'Cosine similarity do not supported {reduction} method.')


# Maps the public metric name to the function that computes it.
METRIC_DICT = {
    'cosine': cosine,
    'sqnr': sqnr,
}


def error_print(metric, q_errors_activ, q_errors_weight, sort_num):
    """Log the quantization error report as a single warning message.

    Args:
        metric: Name of the metric, used as a label in the report
        q_errors_activ: List of ``(name, fake_quant_module, error)`` for activations
        q_errors_weight: List of ``(name, fake_quant_module, error)`` for weights
        sort_num: Value shown in the section headers
    """
    logs = []
    if len(q_errors_weight) > 0:
        logs.append('')
        logs.append(f'Weights ({metric} sorted {sort_num}):')
        for n, m, e in q_errors_weight:
            logs.append(f'{n:40} {metric}: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}')
    if len(q_errors_activ) > 0:
        logs.append('')
        logs.append(f'Activations ({metric} sorted {sort_num}):')
        for n, m, e in q_errors_activ:
            logs.append(f'{n:50} {metric}: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}')
    if len(q_errors_weight) == 0 and len(q_errors_activ) == 0:
        logs.append('')
        logs.append('All good!')
    if len(logs) > 0:
        logs.insert(0, 'Quantization error report:')
        logs.append('')

        full_log = '\n'.join(logs)
        log.warning(full_log)


def _unwrap_parallel(q_model: nn.Module) -> nn.Module:
    """Return the wrapped module of a (Distributed)DataParallel model, else the model itself."""
    if isinstance(q_model, (DataParallel, DistributedDataParallel)):
        return q_model.module
    return q_model


def _prepare_inputs(dummy_input, device) -> list:
    """Normalize ``dummy_input`` to a list and move its tensors to ``device``.

    Raises:
        AssertionError: If ``dummy_input`` is neither a tensor nor a tuple/list.
    """
    if isinstance(dummy_input, torch.Tensor):
        actual_input = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        actual_input = list(dummy_input)
    else:
        message = f'Unsupported type {type(dummy_input)} for dummy input'
        log.error(message)
        # Raised explicitly (instead of `assert False`) so the check survives `python -O`.
        raise AssertionError(message)

    for i, item in enumerate(actual_input):
        if isinstance(item, torch.Tensor) and item.device != device:
            actual_input[i] = item.to(device)

    return actual_input


def _remove_hooks(hooks: list) -> None:
    """Remove every hook handle in ``hooks`` and empty the list (safe to call twice)."""
    for h in hooks:
        h.remove()
    hooks.clear()


def _restore_flags(saved: dict, attr: str) -> None:
    """Write the saved flag tensors back to attribute ``attr`` of each module."""
    for m, v in saved.items():
        setattr(m, attr, v)


def layer_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine', sort_num: int = 20):
    """Generates the layerwise quant error report using the given metric, the q_model need to be qat_prepared.

    The model is run once with fake-quantization disabled to record the float input of
    every FakeQuantize module; each module is then applied to its own recorded input
    and the result is compared with that input.

    Args:
        q_model: The quant prepared model
        dummy_input: A viable input to the model
        metric: Metrics for measuring the error of floating point tensor and quantized tensor.
            Default to be 'cosine', optional 'sqnr'.
        sort_num : The number of layers with the smallest metric value to report
            (``None`` reports all). Defaults to 20

    Raises:
        KeyError: If ``metric`` is not a key of ``METRIC_DICT``.
        AssertionError: If ``dummy_input`` has an unsupported type.
    """
    model = _unwrap_parallel(q_model)
    metric_fn = METRIC_DICT[metric]
    # The value is used as a slice bound, which must be an int (or None for "all").
    limit = None if sort_num is None else int(sort_num)

    train_flag = model.training
    model.eval()

    modules_list = {}
    names_list = {}
    float_results = {}
    hooks = []
    fake_quant_enabled_dict = {}
    observer_enabled_dict = {}

    def forward_hook(module, input, output):
        name = names_list[module]
        float_results[name] = input

    try:
        with torch.no_grad():
            for n, m in model.named_modules():
                if isinstance(m, torch.quantization.FakeQuantize):
                    names_list[m] = n
                    modules_list[n] = m
                    fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                    observer_enabled_dict[m] = m.observer_enabled.clone()
                    hooks.append(m.register_forward_hook(forward_hook))

            if len(modules_list) == 0:
                log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

            # Float pass: neither quantize nor update observer statistics.
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)

            actual_input = _prepare_inputs(dummy_input, get_module_device(model))
            model(*actual_input)

            _remove_hooks(hooks)

            # Re-enable fake-quant only; observers stay off while the modules are re-run below.
            _restore_flags(fake_quant_enabled_dict, 'fake_quant_enabled')

            q_errors_weight = []
            q_errors_activ = []
            # popitem() releases each recorded activation as soon as it has been processed.
            while float_results:
                n, f = float_results.popitem()
                mod = modules_list[n]
                q = mod(*f)
                loss = metric_fn(f[0], q)
                actual_n = '.'.join(n.split('.')[:-1])
                if n.endswith('.weight_fake_quant'):
                    q_errors_weight.append((actual_n, mod, loss))
                else:
                    q_errors_activ.append((actual_n, mod, loss))

        q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2])[:limit]
        q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2])[:limit]

        error_print(metric, q_errors_activ, q_errors_weight, sort_num)
    finally:
        # Always put the model back into the state it was handed over in, even on failure.
        _remove_hooks(hooks)
        _restore_flags(fake_quant_enabled_dict, 'fake_quant_enabled')
        _restore_flags(observer_enabled_dict, 'observer_enabled')
        if train_flag:
            model.train()


def graph_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine'):
    """Generates the cumulative quant error report using the given metric, the q_model need to be qat_prepared.

    The model is run twice on the same input, first with fake-quantization disabled and
    then with it restored; the inputs seen by each activation FakeQuantize module in the
    two runs are compared.

    Args:
        q_model: The quant prepared model.
        dummy_input: A viable input to the model
        metric: Metrics for measuring the error of floating point tensor and quantized tensor.
            Default to be 'cosine', optional 'sqnr'.

    Raises:
        KeyError: If ``metric`` is not a key of ``METRIC_DICT``.
        AssertionError: If ``dummy_input`` has an unsupported type, or a module seen in
            the float run was not reached in the quantized run.
    """
    model = _unwrap_parallel(q_model)
    metric_fn = METRIC_DICT[metric]

    train_flag = model.training
    model.eval()

    modules_list = {}
    names_list = {}
    captured = {}
    hooks = []
    fake_quant_enabled_dict = {}
    observer_enabled_dict = {}

    def forward_hook(module, input, output):
        name = names_list[module]
        captured[name] = input

    try:
        with torch.no_grad():
            for n, m in model.named_modules():
                if isinstance(m, torch.quantization.FakeQuantize):
                    names_list[m] = n
                    modules_list[n] = m
                    fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                    observer_enabled_dict[m] = m.observer_enabled.clone()
                    hooks.append(m.register_forward_hook(forward_hook))

            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)

            if len(modules_list) == 0:
                log.warning('No FakeQuantize modules found. Are you sure you had prepared your model?')

            actual_input = _prepare_inputs(dummy_input, get_module_device(model))

            # First pass: pure floating point activations.
            model(*actual_input)
            float_results = dict(captured)
            captured.clear()

            # Restore fake-quantize and record activation with quantization error.
            _restore_flags(fake_quant_enabled_dict, 'fake_quant_enabled')
            model(*actual_input)

            _remove_hooks(hooks)

            q_errors_activ = []
            for name, f_tensor in float_results.items():
                if name not in captured:
                    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
                    raise AssertionError(f'{name} not in results')
                if name.endswith('.weight_fake_quant'):
                    # Only activations are reported, so weights need no metric evaluation.
                    continue
                actual_n = '.'.join(name.split('.')[:-1])
                loss = metric_fn(f_tensor[0], captured[name][0])
                q_errors_activ.append((actual_n, modules_list[name], loss))

        error_print(metric, q_errors_activ, [], '')
    finally:
        # Always put the model back into the state it was handed over in, even on failure.
        _remove_hooks(hooks)
        _restore_flags(fake_quant_enabled_dict, 'fake_quant_enabled')
        _restore_flags(observer_enabled_dict, 'observer_enabled')
        if train_flag:
            model.train()


def get_weight_dis(
    model: nn.Module,
    unique_name_list: Optional[List[str]] = None,
    nbins=256,
    save_path: str = 'out',
    threshold=20,
    fig_size=(7, 7),
):
    """Draw the weight distribution of model

    One log-scale histogram is saved per selected module that has a tensor ``weight``
    (``nn.BatchNorm2d`` modules are skipped). Layers whose weight range exceeds
    ``threshold`` are listed in a warning at the end.

    Args:
        model: We recommend use ptq-prepared model to draw fused weight distribution
        unique_name_list: You can set the layer which you want to get distribution, default to all layer of model
        nbins: Bins of distribution, default to be 256
        save_path: Weight distribution fig will be saved at "[save_path]/weight_distribution"
        threshold: The threshold of weight range to used to prompt anomalies
        fig_size: Set fig size
    """
    save_dir = os.path.join(save_path, 'weight_distribution')
    log.info(f"jpgs will saved at {save_dir}")
    os.makedirs(save_dir, exist_ok=True)

    warning_layer = dict()
    with torch.no_grad():
        for name, mod in model.named_modules():
            weight = getattr(mod, 'weight', None)
            # `weight` may be missing, None, or not a tensor at all; none of those can be plotted.
            if not isinstance(weight, torch.Tensor) or isinstance(mod, nn.BatchNorm2d):
                continue
            if unique_name_list is not None and name not in unique_name_list:
                continue
            if weight.numel() == 0:
                continue

            op_type = type(mod).__name__
            if op_type in dir(torch.nn.intrinsic.qat) and hasattr(mod, 'bn'):
                # Use torch.nn.util.fusion.fuse_conv_bn_weights to calculate bn_fused conv's weight.
                bn_var_rsqrt = torch.rsqrt(mod.bn.running_var + mod.bn.eps)
                bn_scale = (mod.bn.weight * bn_var_rsqrt).reshape([-1] + [1] * (len(weight.shape) - 1))
                x = (weight * bn_scale).cpu()
            else:
                x = weight.cpu()

            y = torch.histc(x, nbins)
            x_min = torch.min(x)
            x_max = torch.max(x)
            if x_max - x_min > threshold:
                warning_layer[name] = (op_type, float(x_min), float(x_max))

            # Bin centers of the histogram, as plain floats for plotting.
            bin_width = (x_max - x_min) / nbins
            x_s = [float(x_min + (idx + 0.5) * bin_width) for idx in range(nbins)]

            fig, ax = plt.subplots(figsize=fig_size)
            try:
                ax.set_yscale('log')
                ax.plot(x_s, y.detach().numpy())
                ax.set_title(f'Op_uname: {name}[{op_type}]')
                ax.set_xlabel(f'Range:[{x_min:.4f},{x_max:.4f}]')
                ax.set_ylabel('Count')
                fig.savefig(os.path.join(save_dir, f'{name}.jpg'))
            finally:
                # Each layer gets its own figure; close it so figures do not pile up in memory.
                plt.close(fig)

    if warning_layer:
        log_str = f'\n---------the layer weight range length greater than {threshold}---------\n'
        for k, v in warning_layer.items():
            log_str += f'{k}, {v}\n'
        log_str += '---------------------------------------------------------------'
        log.warning(log_str)