in tinynn/util/quantization_analysis_util.py
def layer_error_analysis(q_model: nn.Module, dummy_input, metric: str = 'cosine', sort_num: int = 20):
"""Generates the layerwise quant error report using the given metric, the q_model need to be qat_prepared.
Args:
q_model: The quant prepared model
dummy_input: A viable input to the model
metric: Metrics for measuring the error of floating point tensor and quantized tensor.
Default to be 'cosine', optional 'sqnr'.
sort_num : The smallest sort_num layer0 on given metric. Defaults to 20
"""
    if isinstance(q_model, (DataParallel, DistributedDataParallel)):
        model = q_model.module
    else:
        model = q_model
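    # Look up the metric implementation and switch to eval mode so the measurement is deterministic.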
    metric_fn = METRIC_DICT[metric]

    train_flag = model.training
    model.eval()
    with torch.no_grad():
        modules_list = {}
        names_list = {}
        float_results = {}
        hooks = []
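        # The hook records the floating-point inputs that each FakeQuantize module receives
        # (fake quantization is disabled for the forward pass below, so these tensors are unquantized).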
        def forward_hook(module, input, output):
            name = names_list[module]
            float_results[name] = input
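        # Collect every FakeQuantize module, save its fake_quant/observer switches so they can be
        # restored later, and attach the recording hook.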
        fake_quant_enabled_dict = {}
        observer_enabled_dict = {}
        for n, m in model.named_modules():
            if isinstance(m, torch.quantization.FakeQuantize):
                names_list[m] = n
                modules_list[n] = m

                fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
                observer_enabled_dict[m] = m.observer_enabled.clone()

                hooks.append(m.register_forward_hook(forward_hook))

        if len(modules_list) == 0:
            log.warning('No FakeQuantize modules found. Are you sure you have prepared your model?')
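        # Disable fake quantization and observers so the forward pass produces pure floating-point
        # tensors for the hooks to capture.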
        model.apply(torch.quantization.disable_fake_quant)
        model.apply(torch.quantization.disable_observer)

        device = get_module_device(model)

        if type(dummy_input) is torch.Tensor:
            actual_input = [dummy_input]
        elif isinstance(dummy_input, (tuple, list)):
            actual_input = list(dummy_input)
        else:
            log.error(f'Unsupported type {type(dummy_input)} for dummy input')
            assert False
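        # Move any tensor inputs onto the same device as the model before running the forward pass.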
        for i in range(len(actual_input)):
            dummy_input = actual_input[i]
            if type(dummy_input) is torch.Tensor:
                if dummy_input.device != device:
                    actual_input[i] = dummy_input.to(device)

        with torch.no_grad():
            model(*actual_input)
        for h in hooks:
            h.remove()
        hooks.clear()
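        # Restore the fake-quant switches so each FakeQuantize module actually quantizes its input below.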
        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v
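        # Feed every recorded floating-point input back through its own FakeQuantize module and measure
        # the error between the float tensor and its fake-quantized counterpart, separating weight
        # quantizers from activation quantizers.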
        q_errors_weight = []
        q_errors_activ = []
        while len(float_results) > 0:
            n, f = float_results.popitem()
            mod = modules_list[n]
            with torch.no_grad():
                q = mod(*f)
            loss = metric_fn(f[0], q)
            actual_n = '.'.join(n.split('.')[:-1])
            if n.endswith('.weight_fake_quant'):
                q_errors_weight.append((actual_n, mod, loss))
            else:
                q_errors_activ.append((actual_n, mod, loss))
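        # Both 'cosine' and 'sqnr' are higher-is-better metrics, so sorting ascending and keeping the
        # first sort_num entries selects the layers with the largest quantization error.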
        q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2])
        q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2])
        q_errors_weight = q_errors_weight[:sort_num]
        q_errors_activ = q_errors_activ[:sort_num]
        error_print(metric, q_errors_activ, q_errors_weight, sort_num)
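        # Restore the observer switches and the model's original training mode.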
        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v

    if train_flag:
        model.train()
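# Example usage (a hedged sketch: it assumes TinyNeuralNetwork's QATQuantizer API and uses an
# illustrative torchvision model; adapt the model, work_dir and metric to your own workflow):
#
#     import torch
#     import torchvision
#     from tinynn.graph.quantization.quantizer import QATQuantizer
#     from tinynn.util.quantization_analysis_util import layer_error_analysis
#
#     model = torchvision.models.mobilenet_v2()
#     dummy_input = torch.randn(1, 3, 224, 224)
#
#     quantizer = QATQuantizer(model, dummy_input, work_dir='out')
#     q_model = quantizer.quantize()
#
#     # After calibrating or fine-tuning the QAT model, list the layers with the
#     # largest quantization error according to the chosen metric:
#     layer_error_analysis(q_model, dummy_input, metric='cosine', sort_num=20)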