in tinynn/graph/quantization/quantizer.py [0:0]
def error_analysis(self, qat_model: nn.Module, dummy_input, threshold: float = 20.0):
    """Generates the QAT error report using the SQNR metric

    A forward hook is attached to every ``torch.quantization.FakeQuantize`` module, and one
    forward pass is run (under ``torch.no_grad`` and in eval mode) with fake quantization and
    observers disabled, so that the float input of every fake-quant node is captured. Each node
    is then re-run on its captured input with its original ``fake_quant_enabled`` flag restored
    (observers stay disabled, so statistics are not updated), and the SQNR in dB between the
    float input and the node's output is computed. Nodes with ``SQNR <= threshold`` are reported
    through ``log.warning``, split into weights (names ending in ``.weight_fake_quant``) and
    activations, each sorted by ascending SQNR.

    The hooks are removed and the fake-quant flags, observer flags and the model's original
    training mode are restored afterwards, even if the forward pass or the replay raises.

    Args:
        qat_model: The QAT model (may be wrapped in ``DataParallel``/``DistributedDataParallel``)
        dummy_input: A viable input to the model: a tensor, or a tuple/list of positional inputs
        threshold (float): The threshold of SQNR. Defaults to 20.0

    Raises:
        AssertionError: If ``dummy_input`` is neither a tensor nor a tuple/list.
    """
    if isinstance(qat_model, (DataParallel, DistributedDataParallel)):
        model = qat_model.module
    else:
        model = qat_model

    modules_list = {}  # name -> FakeQuantize module
    names_list = {}  # FakeQuantize module -> name
    float_results = {}  # name -> input tuple captured during the float forward pass
    fake_quant_enabled_dict = {}
    observer_enabled_dict = {}

    for n, m in model.named_modules():
        if isinstance(m, torch.quantization.FakeQuantize):
            names_list[m] = n
            modules_list[n] = m
            fake_quant_enabled_dict[m] = m.fake_quant_enabled.clone()
            observer_enabled_dict[m] = m.observer_enabled.clone()

    if len(modules_list) == 0:
        log.warning('No FakeQuantize modules found. Are you sure you passed in a QAT model?')
        return

    # Validate and prepare the inputs *before* touching the model, so a bad input
    # cannot leave hooks attached or fake quantization disabled.
    device = get_module_device(model)
    if isinstance(dummy_input, torch.Tensor):
        actual_input = [dummy_input]
    elif isinstance(dummy_input, (tuple, list)):
        actual_input = list(dummy_input)
    else:
        log.error(f'Unsupported type {type(dummy_input)} for dummy input')
        # Explicit raise instead of `assert False`, which is stripped under `python -O`.
        raise AssertionError(f'Unsupported type {type(dummy_input)} for dummy input')
    actual_input = [
        x.to(device) if isinstance(x, torch.Tensor) and x.device != device else x for x in actual_input
    ]

    def forward_hook(module, input, output):
        # Keep the (float) input tuple of the fake-quant node; it is replayed below.
        float_results[names_list[module]] = input

    def sqnr(x, y):
        # SQNR in dB: 20 * log10(||signal|| / ||signal - quantized||)
        Ps = torch.norm(x)
        Pn = torch.norm(x - y)
        return (20 * torch.log10(Ps / Pn)).item()

    was_training = model.training
    try:
        hooks = [m.register_forward_hook(forward_hook) for m in names_list]
        try:
            model.apply(torch.quantization.disable_fake_quant)
            model.apply(torch.quantization.disable_observer)
            model.eval()
            with torch.no_grad():
                model(*actual_input)
        finally:
            for h in hooks:
                h.remove()
            # Re-enable fake quantization (as originally configured) so each node can be
            # replayed on its float input; observers stay disabled until the very end.
            for m, v in fake_quant_enabled_dict.items():
                m.fake_quant_enabled = v

        q_errors_weight = []
        q_errors_activ = []
        while len(float_results) > 0:
            n, f = float_results.popitem()
            mod = modules_list[n]
            with torch.no_grad():
                q = mod(*f)
            loss = sqnr(f[0], q)
            # Strip the fake-quant suffix to get the name of the owning module
            actual_n = '.'.join(n.split('.')[:-1])
            if loss <= threshold:
                if n.endswith('.weight_fake_quant'):
                    q_errors_weight.append((actual_n, mod, loss))
                else:
                    q_errors_activ.append((actual_n, mod, loss))

        q_errors_weight = sorted(q_errors_weight, key=lambda x: x[2])
        q_errors_activ = sorted(q_errors_activ, key=lambda x: x[2])

        logs = []
        if len(q_errors_weight) > 0:
            logs.append('')
            logs.append(f'Weights (SQNR <= {threshold}):')
            for n, m, e in q_errors_weight:
                logs.append(f'{n} SQNR: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}')

        if len(q_errors_activ) > 0:
            logs.append('')
            logs.append(f'Activations (SQNR <= {threshold}):')
            for n, m, e in q_errors_activ:
                logs.append(f'{n} SQNR: {e:.4f}, scale: {m.scale.item():.4f}, zero_point: {m.zero_point.item()}')

        if len(q_errors_weight) == 0 and len(q_errors_activ) == 0:
            logs.append('')
            logs.append('All good!')

        if len(logs) > 0:
            logs.insert(0, 'Quantization error report:')
            logs.append('')
            full_log = '\n'.join(logs)
            log.warning(full_log)
    finally:
        # Always leave the model the way we found it.
        for m, v in fake_quant_enabled_dict.items():
            m.fake_quant_enabled = v
        for m, v in observer_enabled_dict.items():
            m.observer_enabled = v
        model.train(was_training)